mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 20:08:09 +00:00
Use Storage Account types, rather than account_id (#320)
We need to move towards supporting data sharding. One of the steps in that direction is to stop passing around `account_id`; instead, callers specify the type of storage they need.
This commit is contained in:
@ -6,37 +6,61 @@
|
||||
import datetime
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Dict, Optional, Union, cast
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union, cast
|
||||
|
||||
from azure.common import AzureHttpError, AzureMissingResourceHttpError
|
||||
from azure.storage.blob import BlobPermissions, ContainerPermissions
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_blob_service
|
||||
from .creds import get_blob_service, get_func_storage, get_fuzz_storage
|
||||
|
||||
|
||||
class StorageType(Enum):
    """Kinds of storage a caller can request instead of naming an account id."""

    # Resolved through get_fuzz_storage() by get_account_id_by_type().
    corpus = "corpus"
    # Resolved through get_func_storage() by get_account_id_by_type().
    config = "config"
|
||||
|
||||
|
||||
def get_account_id_by_type(storage_type: StorageType) -> str:
    """Return the storage account id that backs ``storage_type``.

    Args:
        storage_type: The kind of storage being requested.

    Returns:
        The result of ``get_fuzz_storage()`` for ``StorageType.corpus`` and
        of ``get_func_storage()`` for ``StorageType.config``.

    Raises:
        NotImplementedError: If ``storage_type`` is not one of the members
            handled above.
    """
    if storage_type == StorageType.corpus:
        return get_fuzz_storage()
    if storage_type == StorageType.config:
        return get_func_storage()
    # Name the offending value so an unhandled member is easy to diagnose.
    raise NotImplementedError("unsupported storage type: %r" % (storage_type,))
|
||||
|
||||
|
||||
@cached(ttl=5)
|
||||
def container_exists(name: str, account_id: Optional[str] = None) -> bool:
|
||||
def get_blob_service_by_type(storage_type: StorageType) -> Any:
    """Return the blob service for the account that backs ``storage_type``.

    The account id is looked up with ``get_account_id_by_type`` and handed
    to ``get_blob_service``.
    """
    return get_blob_service(get_account_id_by_type(storage_type))
|
||||
|
||||
|
||||
@cached(ttl=5)
def container_exists(name: str, storage_type: StorageType) -> bool:
    """Report whether container ``name`` exists in the ``storage_type`` account.

    Existence is probed by fetching the container's properties; any
    ``AzureHttpError`` raised while doing so is reported as ``False``.
    Results are memoized by the ``cached(ttl=5)`` decorator.
    """
    try:
        service = get_blob_service_by_type(storage_type)
        service.get_container_properties(name)
    except AzureHttpError:
        found = False
    else:
        found = True
    return found
|
||||
|
||||
|
||||
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
    """Map each container name to its metadata for the ``storage_type`` account.

    Containers whose name starts with ``$`` are left out of the result.
    """
    service = get_blob_service_by_type(storage_type)
    containers: Dict[str, Dict[str, str]] = {}
    for entry in service.list_containers(include_metadata=True):
        if entry.name.startswith("$"):
            continue
        containers[entry.name] = entry.metadata
    return containers
|
||||
|
||||
|
||||
def get_container_metadata(
|
||||
name: str, account_id: Optional[str] = None
|
||||
name: str, storage_type: StorageType
|
||||
) -> Optional[Dict[str, str]]:
|
||||
try:
|
||||
result = get_blob_service(account_id).get_container_metadata(name)
|
||||
result = get_blob_service_by_type(storage_type).get_container_metadata(name)
|
||||
return cast(Dict[str, str], result)
|
||||
except AzureHttpError:
|
||||
pass
|
||||
@ -44,22 +68,29 @@ def get_container_metadata(
|
||||
|
||||
|
||||
def create_container(
    name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]]
) -> Optional[str]:
    """Create container ``name`` and return a SAS URL for it.

    Args:
        name: Name of the container to create.
        storage_type: Kind of storage account to create the container in.
        metadata: Metadata passed through to the blob service, or ``None``.

    Returns:
        The value of ``get_container_sas_url`` requested with read, add,
        create, write, delete and list all enabled, or ``None`` when the
        create call raises ``AzureHttpError``.
    """
    try:
        service = get_blob_service_by_type(storage_type)
        service.create_container(name, metadata=metadata)
    except AzureHttpError:
        # azure storage already logs errors
        return None

    permissions = dict(
        read=True,
        add=True,
        create=True,
        write=True,
        delete=True,
        list=True,
    )
    return get_container_sas_url(name, storage_type, **permissions)
|
||||
|
||||
|
||||
def delete_container(name: str, storage_type: StorageType) -> bool:
    """Delete container ``name`` from the ``storage_type`` account.

    Returns:
        The truthiness of the blob service's ``delete_container`` result, or
        ``False`` when the call raises ``AzureHttpError``.
    """
    try:
        service = get_blob_service_by_type(storage_type)
        outcome = service.delete_container(name)
    except AzureHttpError:
        # azure storage already logs errors
        return False
    return bool(outcome)
|
||||
@ -67,7 +98,8 @@ def delete_container(name: str, account_id: Optional[str] = None) -> bool:
|
||||
|
||||
def get_container_sas_url(
|
||||
container: str,
|
||||
account_id: Optional[str] = None,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
create: bool = False,
|
||||
@ -75,7 +107,7 @@ def get_container_sas_url(
|
||||
delete: bool = False,
|
||||
list: bool = False,
|
||||
) -> str:
|
||||
service = get_blob_service(account_id)
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
permission = ContainerPermissions(read, add, create, write, delete, list)
|
||||
|
||||
@ -91,7 +123,8 @@ def get_container_sas_url(
|
||||
def get_file_sas_url(
|
||||
container: str,
|
||||
name: str,
|
||||
account_id: Optional[str] = None,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
create: bool = False,
|
||||
@ -102,7 +135,7 @@ def get_file_sas_url(
|
||||
hours: int = 0,
|
||||
minutes: int = 0,
|
||||
) -> str:
|
||||
service = get_blob_service(account_id)
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(
|
||||
days=days, hours=hours, minutes=minutes
|
||||
)
|
||||
@ -117,9 +150,9 @@ def get_file_sas_url(
|
||||
|
||||
|
||||
def save_blob(
|
||||
container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
|
||||
container: str, name: str, data: Union[str, bytes], storage_type: StorageType
|
||||
) -> None:
|
||||
service = get_blob_service(account_id)
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
service.create_container(container)
|
||||
if isinstance(data, str):
|
||||
service.create_blob_from_text(container, name, data)
|
||||
@ -127,10 +160,8 @@ def save_blob(
|
||||
service.create_blob_from_bytes(container, name, data)
|
||||
|
||||
|
||||
def get_blob(
|
||||
container: str, name: str, account_id: Optional[str] = None
|
||||
) -> Optional[bytes]:
|
||||
service = get_blob_service(account_id)
|
||||
def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
try:
|
||||
blob = service.get_blob_to_bytes(container, name).content
|
||||
return cast(bytes, blob)
|
||||
@ -138,8 +169,8 @@ def get_blob(
|
||||
return None
|
||||
|
||||
|
||||
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
|
||||
service = get_blob_service(account_id)
|
||||
def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
try:
|
||||
service.get_blob_properties(container, name)
|
||||
return True
|
||||
@ -147,8 +178,8 @@ def blob_exists(container: str, name: str, account_id: Optional[str] = None) ->
|
||||
return False
|
||||
|
||||
|
||||
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
|
||||
service = get_blob_service(account_id)
|
||||
def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
try:
|
||||
service.delete_blob(container, name)
|
||||
return True
|
||||
|
@ -87,12 +87,12 @@ def get_insights_appid() -> str:
|
||||
return os.environ["APPINSIGHTS_APPID"]
|
||||
|
||||
|
||||
# Not memoized: the ``cached`` decorator is disabled for this function, so
# the environment is consulted on every call.
def get_fuzz_storage() -> str:
    """Return the value of the ``ONEFUZZ_DATA_STORAGE`` environment variable.

    Raises:
        KeyError: If the variable is not set.
    """
    storage_account = os.environ["ONEFUZZ_DATA_STORAGE"]
    return storage_account
|
||||
|
||||
|
||||
# Not memoized: the ``cached`` decorator is disabled for this function, so
# the environment is consulted on every call.
def get_func_storage() -> str:
    """Return the value of the ``ONEFUZZ_FUNC_STORAGE`` environment variable.

    Raises:
        KeyError: If the variable is not set.
    """
    storage_account = os.environ["ONEFUZZ_FUNC_STORAGE"]
    return storage_account
|
||||
|
||||
@ -109,9 +109,9 @@ def get_instance_url() -> str:
|
||||
|
||||
@cached
def get_instance_id() -> UUID:
    """Return the instance id stored in the config storage account.

    The id is read from blob ``instance_id`` in container ``base-config``
    and parsed as a UUID.

    Raises:
        Exception: If the blob cannot be read (``get_blob`` returned None).
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably this avoids a circular import with
    # .containers -- confirm.
    from .containers import StorageType, get_blob

    raw = get_blob("base-config", "instance_id", StorageType.config)
    if raw is not None:
        return UUID(raw.decode())
    raise Exception("missing instance_id")
|
||||
|
@ -19,6 +19,7 @@ from azure.storage.queue import (
|
||||
from memoization import cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .containers import StorageType, get_account_id_by_type
|
||||
from .creds import get_storage_account_name_key
|
||||
|
||||
QueueNameType = Union[str, UUID]
|
||||
@ -27,7 +28,8 @@ DEFAULT_TTL = -1
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_queue_client(account_id: str) -> QueueServiceClient:
|
||||
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
|
||||
account_id = get_account_id_by_type(storage_type)
|
||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
account_url = "https://%s.queue.core.windows.net" % name
|
||||
@ -41,13 +43,14 @@ def get_queue_client(account_id: str) -> QueueServiceClient:
|
||||
@cached(ttl=60)
|
||||
def get_queue_sas(
|
||||
queue: QueueNameType,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
account_id: str,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
update: bool = False,
|
||||
process: bool = False,
|
||||
) -> str:
|
||||
account_id = get_account_id_by_type(storage_type)
|
||||
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
@ -67,31 +70,33 @@ def get_queue_sas(
|
||||
|
||||
|
||||
@cached(ttl=60)
def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
    """Create queue ``name`` in the ``storage_type`` account.

    A queue that already exists (``ResourceExistsError``) is not treated as
    an error.  Calls are memoized by the ``cached(ttl=60)`` decorator.
    """
    service = get_queue_client(storage_type)
    try:
        service.create_queue(str(name))
    except ResourceExistsError:
        # The queue is already there; nothing left to do.
        return
|
||||
|
||||
|
||||
def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
    """Delete queue ``name`` from the ``storage_type`` account if it is listed.

    Args:
        name: Queue name; converted with ``str()`` before use, as in
            ``create_queue`` and ``get_queue``.
        storage_type: Kind of storage account that holds the queue.

    The queue is only deleted when it appears in the account's queue
    listing; otherwise nothing happens.
    """
    client = get_queue_client(storage_type)
    # Normalize once so the membership check and the delete call agree on
    # the same string name (``name`` is not necessarily a ``str``).
    queue_name = str(name)
    if any(queue["name"] == queue_name for queue in client.list_queues()):
        client.delete_queue(queue_name)
|
||||
|
||||
|
||||
def get_queue(
    name: QueueNameType, storage_type: StorageType
) -> Optional[QueueServiceClient]:
    """Return a client for queue ``name`` in the ``storage_type`` account.

    Returns:
        The object produced by the service client's ``get_queue_client`` for
        ``str(name)``, or ``None`` when that call raises
        ``ResourceNotFoundError``.
    """
    service = get_queue_client(storage_type)
    try:
        queue = service.get_queue_client(str(name))
    except ResourceNotFoundError:
        return None
    return queue
|
||||
|
||||
|
||||
def clear_queue(name: QueueNameType, *, account_id: str) -> None:
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
||||
queue = get_queue(name, storage_type)
|
||||
if queue:
|
||||
try:
|
||||
queue.clear_messages()
|
||||
@ -102,12 +107,12 @@ def clear_queue(name: QueueNameType, *, account_id: str) -> None:
|
||||
def send_message(
|
||||
name: QueueNameType,
|
||||
message: bytes,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
account_id: str,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
time_to_live: int = DEFAULT_TTL,
|
||||
) -> None:
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
queue = get_queue(name, storage_type)
|
||||
if queue:
|
||||
try:
|
||||
queue.send_message(
|
||||
@ -119,9 +124,8 @@ def send_message(
|
||||
pass
|
||||
|
||||
|
||||
def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
|
||||
create_queue(name, account_id=account_id)
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
|
||||
queue = get_queue(name, storage_type)
|
||||
if queue:
|
||||
try:
|
||||
for message in queue.receive_messages():
|
||||
@ -143,8 +147,8 @@ MAX_PEEK_SIZE = 32
|
||||
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
|
||||
def peek_queue(
|
||||
name: QueueNameType,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
account_id: str,
|
||||
object_type: Type[A],
|
||||
max_messages: int = MAX_PEEK_SIZE,
|
||||
) -> List[A]:
|
||||
@ -154,7 +158,7 @@ def peek_queue(
|
||||
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
|
||||
raise ValueError("invalid max messages: %s" % max_messages)
|
||||
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
queue = get_queue(name, storage_type)
|
||||
if not queue:
|
||||
return result
|
||||
|
||||
@ -168,12 +172,12 @@ def peek_queue(
|
||||
def queue_object(
|
||||
name: QueueNameType,
|
||||
message: BaseModel,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
account_id: str,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
time_to_live: int = DEFAULT_TTL,
|
||||
) -> bool:
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
queue = get_queue(name, storage_type)
|
||||
if not queue:
|
||||
raise Exception("unable to queue object, no such queue: %s" % queue)
|
||||
|
||||
|
Reference in New Issue
Block a user