Use Storage Account types, rather than account_id (#320)

We need to move to supporting data sharding.

One of the steps towards that is to stop passing around `account_id`; instead, we specify the type of storage we need.
This commit is contained in:
bmc-msft 2020-11-18 09:06:14 -05:00 committed by GitHub
parent 52eca33237
commit e47e89609a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 205 additions and 139 deletions

View File

@ -13,7 +13,8 @@ from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_url
from ..onefuzzlib.azure.containers import StorageType
from ..onefuzzlib.azure.creds import get_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool
from ..onefuzzlib.request import not_ok, ok, parse_uri
@ -25,7 +26,7 @@ def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpRespo
commands_url = "%s/api/agents/commands" % base_address
work_queue = get_queue_sas(
pool.get_pool_queue(),
account_id=get_fuzz_storage(),
StorageType.corpus,
read=True,
update=True,
process=True,

View File

@ -13,6 +13,7 @@ from onefuzztypes.requests import ContainerCreate, ContainerDelete, ContainerGet
from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase
from ..onefuzzlib.azure.containers import (
StorageType,
create_container,
delete_container,
get_container_metadata,
@ -30,7 +31,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(request, Error):
return not_ok(request, context="container get")
if request is not None:
metadata = get_container_metadata(request.name)
metadata = get_container_metadata(request.name, StorageType.corpus)
if metadata is None:
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
@ -41,6 +42,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
name=request.name,
sas_url=get_container_sas_url(
request.name,
StorageType.corpus,
read=True,
write=True,
create=True,
@ -51,7 +53,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
)
return ok(info)
containers = get_containers()
containers = get_containers(StorageType.corpus)
container_info = []
for name in containers:
@ -66,7 +68,7 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
return not_ok(request, context="container create")
logging.info("container - creating %s", request.name)
sas = create_container(request.name, metadata=request.metadata)
sas = create_container(request.name, StorageType.corpus, metadata=request.metadata)
if sas:
return ok(
ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata)
@ -83,7 +85,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
return not_ok(request, context="container delete")
logging.info("container - deleting %s", request.name)
return ok(BoolResult(result=delete_container(request.name)))
return ok(BoolResult(result=delete_container(request.name, StorageType.corpus)))
def main(req: func.HttpRequest) -> func.HttpResponse:

View File

@ -8,6 +8,7 @@ from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, FileEntry
from ..onefuzzlib.azure.containers import (
StorageType,
blob_exists,
container_exists,
get_file_sas_url,
@ -20,13 +21,13 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(request, Error):
return not_ok(request, context="download")
if not container_exists(request.container):
if not container_exists(request.container, StorageType.corpus):
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
context=request.container,
)
if not blob_exists(request.container, request.filename):
if not blob_exists(request.container, request.filename, StorageType.corpus):
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]),
context=request.filename,
@ -34,7 +35,12 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
return redirect(
get_file_sas_url(
request.container, request.filename, read=True, days=0, minutes=5
request.container,
request.filename,
StorageType.corpus,
read=True,
days=0,
minutes=5,
)
)

View File

@ -6,37 +6,61 @@
import datetime
import os
import urllib.parse
from typing import Dict, Optional, Union, cast
from enum import Enum
from typing import Any, Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions
from memoization import cached
from .creds import get_blob_service
from .creds import get_blob_service, get_func_storage, get_fuzz_storage
class StorageType(Enum):
    """Logical categories of storage account used by the service.

    Callers select a category instead of passing a raw ``account_id``;
    ``get_account_id_by_type`` resolves each member to a concrete account.
    """

    # fuzzing/corpus data — resolved via get_fuzz_storage() (ONEFUZZ_DATA_STORAGE)
    corpus = "corpus"
    # service configuration artifacts — resolved via get_func_storage() (ONEFUZZ_FUNC_STORAGE)
    config = "config"
def get_account_id_by_type(storage_type: StorageType) -> str:
    """Resolve a logical storage type to its storage account id.

    :param storage_type: which category of storage is needed.
    :returns: the account id for that category.
    :raises NotImplementedError: for storage types without a mapping.
    """
    if storage_type == StorageType.corpus:
        return get_fuzz_storage()
    if storage_type == StorageType.config:
        return get_func_storage()
    raise NotImplementedError
@cached(ttl=5)
def container_exists(name: str, account_id: Optional[str] = None) -> bool:
def get_blob_service_by_type(storage_type: StorageType) -> Any:
    """Return a blob service client for the account backing *storage_type*."""
    return get_blob_service(get_account_id_by_type(storage_type))
@cached(ttl=5)
def container_exists(name: str, storage_type: StorageType) -> bool:
try:
get_blob_service(account_id).get_container_properties(name)
get_blob_service_by_type(storage_type).get_container_properties(name)
return True
except AzureHttpError:
return False
def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]:
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
return {
x.name: x.metadata
for x in get_blob_service(account_id).list_containers(include_metadata=True)
for x in get_blob_service_by_type(storage_type).list_containers(
include_metadata=True
)
if not x.name.startswith("$")
}
def get_container_metadata(
name: str, account_id: Optional[str] = None
name: str, storage_type: StorageType
) -> Optional[Dict[str, str]]:
try:
result = get_blob_service(account_id).get_container_metadata(name)
result = get_blob_service_by_type(storage_type).get_container_metadata(name)
return cast(Dict[str, str], result)
except AzureHttpError:
pass
@ -44,22 +68,29 @@ def get_container_metadata(
def create_container(
name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None
name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]]
) -> Optional[str]:
try:
get_blob_service(account_id).create_container(name, metadata=metadata)
get_blob_service_by_type(storage_type).create_container(name, metadata=metadata)
except AzureHttpError:
# azure storage already logs errors
return None
return get_container_sas_url(
name, read=True, add=True, create=True, write=True, delete=True, list=True
name,
storage_type,
read=True,
add=True,
create=True,
write=True,
delete=True,
list=True,
)
def delete_container(name: str, account_id: Optional[str] = None) -> bool:
def delete_container(name: str, storage_type: StorageType) -> bool:
try:
return bool(get_blob_service(account_id).delete_container(name))
return bool(get_blob_service_by_type(storage_type).delete_container(name))
except AzureHttpError:
# azure storage already logs errors
return False
@ -67,7 +98,8 @@ def delete_container(name: str, account_id: Optional[str] = None) -> bool:
def get_container_sas_url(
container: str,
account_id: Optional[str] = None,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
@ -75,7 +107,7 @@ def get_container_sas_url(
delete: bool = False,
list: bool = False,
) -> str:
service = get_blob_service(account_id)
service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
permission = ContainerPermissions(read, add, create, write, delete, list)
@ -91,7 +123,8 @@ def get_container_sas_url(
def get_file_sas_url(
container: str,
name: str,
account_id: Optional[str] = None,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
@ -102,7 +135,7 @@ def get_file_sas_url(
hours: int = 0,
minutes: int = 0,
) -> str:
service = get_blob_service(account_id)
service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(
days=days, hours=hours, minutes=minutes
)
@ -117,9 +150,9 @@ def get_file_sas_url(
def save_blob(
container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
container: str, name: str, data: Union[str, bytes], storage_type: StorageType
) -> None:
service = get_blob_service(account_id)
service = get_blob_service_by_type(storage_type)
service.create_container(container)
if isinstance(data, str):
service.create_blob_from_text(container, name, data)
@ -127,10 +160,8 @@ def save_blob(
service.create_blob_from_bytes(container, name, data)
def get_blob(
container: str, name: str, account_id: Optional[str] = None
) -> Optional[bytes]:
service = get_blob_service(account_id)
def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]:
service = get_blob_service_by_type(storage_type)
try:
blob = service.get_blob_to_bytes(container, name).content
return cast(bytes, blob)
@ -138,8 +169,8 @@ def get_blob(
return None
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
service = get_blob_service(account_id)
def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type)
try:
service.get_blob_properties(container, name)
return True
@ -147,8 +178,8 @@ def blob_exists(container: str, name: str, account_id: Optional[str] = None) ->
return False
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
service = get_blob_service(account_id)
def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type)
try:
service.delete_blob(container, name)
return True

View File

@ -87,12 +87,12 @@ def get_insights_appid() -> str:
return os.environ["APPINSIGHTS_APPID"]
@cached
# @cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached
# @cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@ -109,9 +109,9 @@ def get_instance_url() -> str:
@cached
def get_instance_id() -> UUID:
from .containers import get_blob
from .containers import StorageType, get_blob
blob = get_blob("base-config", "instance_id", account_id=get_func_storage())
blob = get_blob("base-config", "instance_id", StorageType.config)
if blob is None:
raise Exception("missing instance_id")
return UUID(blob.decode())

View File

@ -19,6 +19,7 @@ from azure.storage.queue import (
from memoization import cached
from pydantic import BaseModel
from .containers import StorageType, get_account_id_by_type
from .creds import get_storage_account_name_key
QueueNameType = Union[str, UUID]
@ -27,7 +28,8 @@ DEFAULT_TTL = -1
@cached(ttl=60)
def get_queue_client(account_id: str) -> QueueServiceClient:
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
account_id = get_account_id_by_type(storage_type)
logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
account_url = "https://%s.queue.core.windows.net" % name
@ -41,13 +43,14 @@ def get_queue_client(account_id: str) -> QueueServiceClient:
@cached(ttl=60)
def get_queue_sas(
queue: QueueNameType,
storage_type: StorageType,
*,
account_id: str,
read: bool = False,
add: bool = False,
update: bool = False,
process: bool = False,
) -> str:
account_id = get_account_id_by_type(storage_type)
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
name, key = get_storage_account_name_key(account_id)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
@ -67,31 +70,33 @@ def get_queue_sas(
@cached(ttl=60)
def create_queue(name: QueueNameType, *, account_id: str) -> None:
client = get_queue_client(account_id)
def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(storage_type)
try:
client.create_queue(str(name))
except ResourceExistsError:
pass
def delete_queue(name: QueueNameType, *, account_id: str) -> None:
client = get_queue_client(account_id)
def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(storage_type)
queues = client.list_queues()
if str(name) in [x["name"] for x in queues]:
client.delete_queue(name)
def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]:
client = get_queue_client(account_id)
def get_queue(
name: QueueNameType, storage_type: StorageType
) -> Optional[QueueServiceClient]:
client = get_queue_client(storage_type)
try:
return client.get_queue_client(str(name))
except ResourceNotFoundError:
return None
def clear_queue(name: QueueNameType, *, account_id: str) -> None:
queue = get_queue(name, account_id=account_id)
def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
queue = get_queue(name, storage_type)
if queue:
try:
queue.clear_messages()
@ -102,12 +107,12 @@ def clear_queue(name: QueueNameType, *, account_id: str) -> None:
def send_message(
name: QueueNameType,
message: bytes,
storage_type: StorageType,
*,
account_id: str,
visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL,
) -> None:
queue = get_queue(name, account_id=account_id)
queue = get_queue(name, storage_type)
if queue:
try:
queue.send_message(
@ -119,9 +124,8 @@ def send_message(
pass
def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
create_queue(name, account_id=account_id)
queue = get_queue(name, account_id=account_id)
def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
queue = get_queue(name, storage_type)
if queue:
try:
for message in queue.receive_messages():
@ -143,8 +147,8 @@ MAX_PEEK_SIZE = 32
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue(
name: QueueNameType,
storage_type: StorageType,
*,
account_id: str,
object_type: Type[A],
max_messages: int = MAX_PEEK_SIZE,
) -> List[A]:
@ -154,7 +158,7 @@ def peek_queue(
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
raise ValueError("invalid max messages: %s" % max_messages)
queue = get_queue(name, account_id=account_id)
queue = get_queue(name, storage_type)
if not queue:
return result
@ -168,12 +172,12 @@ def peek_queue(
def queue_object(
name: QueueNameType,
message: BaseModel,
storage_type: StorageType,
*,
account_id: str,
visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL,
) -> bool:
queue = get_queue(name, account_id=account_id)
queue = get_queue(name, storage_type)
if not queue:
raise Exception("unable to queue object, no such queue: %s" % queue)

View File

@ -11,8 +11,13 @@ from onefuzztypes.enums import OS, AgentMode
from onefuzztypes.models import AgentConfig, ReproConfig
from onefuzztypes.primitives import Extension, Region
from .azure.containers import get_container_sas_url, get_file_sas_url, save_blob
from .azure.creds import get_func_storage, get_instance_id, get_instance_url
from .azure.containers import (
StorageType,
get_container_sas_url,
get_file_sas_url,
save_blob,
)
from .azure.creds import get_instance_id, get_instance_url
from .azure.monitor import get_monitor_settings
from .azure.queue import get_queue_sas
from .reports import get_report
@ -96,7 +101,7 @@ def build_pool_config(pool_name: str) -> str:
instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
heartbeat_queue=get_queue_sas(
"node-heartbeat",
account_id=get_func_storage(),
StorageType.config,
add=True,
),
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
@ -107,13 +112,13 @@ def build_pool_config(pool_name: str) -> str:
"vm-scripts",
"%s/config.json" % pool_name,
config.json(),
account_id=get_func_storage(),
StorageType.config,
)
return get_file_sas_url(
"vm-scripts",
"%s/config.json" % pool_name,
account_id=get_func_storage(),
StorageType.config,
read=True,
)
@ -124,30 +129,26 @@ def update_managed_scripts() -> None:
% (
get_container_sas_url(
"instance-specific-setup",
StorageType.config,
read=True,
list=True,
account_id=get_func_storage(),
)
),
"azcopy sync '%s' tools"
% (
get_container_sas_url(
"tools", read=True, list=True, account_id=get_func_storage()
)
),
% (get_container_sas_url("tools", StorageType.config, read=True, list=True)),
]
save_blob(
"vm-scripts",
"managed.ps1",
"\r\n".join(commands) + "\r\n",
account_id=get_func_storage(),
StorageType.config,
)
save_blob(
"vm-scripts",
"managed.sh",
"\n".join(commands) + "\n",
account_id=get_func_storage(),
StorageType.config,
)
@ -164,25 +165,25 @@ def agent_config(
get_file_sas_url(
"vm-scripts",
"managed.ps1",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"tools",
"win64/azcopy.exe",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"tools",
"win64/setup.ps1",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"tools",
"win64/onefuzz.ps1",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
]
@ -206,19 +207,19 @@ def agent_config(
get_file_sas_url(
"vm-scripts",
"managed.sh",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"tools",
"linux/azcopy",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"tools",
"linux/setup.sh",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
]
@ -263,13 +264,22 @@ def repro_extensions(
if setup_container:
commands += [
"azcopy sync '%s' ./setup"
% (get_container_sas_url(setup_container, read=True, list=True)),
% (
get_container_sas_url(
setup_container, StorageType.corpus, read=True, list=True
)
),
]
urls = [
get_file_sas_url(repro_config.container, repro_config.path, read=True),
get_file_sas_url(
report.input_blob.container, report.input_blob.name, read=True
repro_config.container, repro_config.path, StorageType.corpus, read=True
),
get_file_sas_url(
report.input_blob.container,
report.input_blob.name,
StorageType.corpus,
read=True,
),
]
@ -288,7 +298,7 @@ def repro_extensions(
"task-configs",
"%s/%s" % (repro_id, script_name),
task_script,
account_id=get_func_storage(),
StorageType.config,
)
for repro_file in repro_files:
@ -296,13 +306,13 @@ def repro_extensions(
get_file_sas_url(
"repro-scripts",
repro_file,
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"task-configs",
"%s/%s" % (repro_id, script_name),
account_id=get_func_storage(),
StorageType.config,
read=True,
),
]
@ -318,13 +328,13 @@ def proxy_manager_extensions(region: Region) -> List[Extension]:
get_file_sas_url(
"proxy-configs",
"%s/config.json" % region,
account_id=get_func_storage(),
StorageType.config,
read=True,
),
get_file_sas_url(
"tools",
"linux/onefuzz-proxy-manager",
account_id=get_func_storage(),
StorageType.config,
read=True,
),
]

View File

@ -21,11 +21,11 @@ from onefuzztypes.models import (
from onefuzztypes.primitives import Container, Event
from ..azure.containers import (
StorageType,
container_exists,
get_container_metadata,
get_file_sas_url,
)
from ..azure.creds import get_fuzz_storage
from ..azure.queue import send_message
from ..dashboard import add_event
from ..orm import ORMMixin
@ -72,7 +72,7 @@ class Notification(models.Notification, ORMMixin):
def create(
cls, container: Container, config: NotificationTemplate
) -> Result["Notification"]:
if not container_exists(container):
if not container_exists(container, StorageType.corpus):
return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"])
existing = cls.get_existing(container, config)
@ -106,7 +106,7 @@ def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
@cached(ttl=60)
def container_metadata(container: Container) -> Optional[Dict[str, str]]:
return get_container_metadata(container)
return get_container_metadata(container, StorageType.corpus)
def new_files(container: Container, filename: str) -> None:
@ -149,9 +149,9 @@ def new_files(container: Container, filename: str) -> None:
for (task, containers) in get_queue_tasks():
if container in containers:
logging.info("queuing input %s %s %s", container, filename, task.task_id)
url = get_file_sas_url(container, filename, read=True, delete=True)
send_message(
task.task_id, bytes(url, "utf-8"), account_id=get_fuzz_storage()
url = get_file_sas_url(
container, filename, StorageType.corpus, read=True, delete=True
)
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)
add_event("new_file", results)

View File

@ -35,7 +35,7 @@ from pydantic import BaseModel, Field
from .__version__ import __version__
from .azure.auth import build_auth
from .azure.creds import get_func_storage, get_fuzz_storage
from .azure.containers import StorageType
from .azure.image import get_os
from .azure.network import Network
from .azure.queue import (
@ -442,7 +442,7 @@ class Pool(BASE_POOL, ORMMixin):
return
worksets = peek_queue(
self.get_pool_queue(), account_id=get_fuzz_storage(), object_type=WorkSet
self.get_pool_queue(), StorageType.corpus, object_type=WorkSet
)
for workset in worksets:
@ -460,7 +460,7 @@ class Pool(BASE_POOL, ORMMixin):
return "pool-%s" % self.pool_id.hex
def init(self) -> None:
create_queue(self.get_pool_queue(), account_id=get_fuzz_storage())
create_queue(self.get_pool_queue(), StorageType.corpus)
self.state = PoolState.running
self.save()
@ -470,7 +470,9 @@ class Pool(BASE_POOL, ORMMixin):
return False
return queue_object(
self.get_pool_queue(), work_set, account_id=get_fuzz_storage()
self.get_pool_queue(),
work_set,
StorageType.corpus,
)
@classmethod
@ -531,7 +533,7 @@ class Pool(BASE_POOL, ORMMixin):
scalesets = Scaleset.search_by_pool(self.name)
nodes = Node.search(query={"pool_name": [self.name]})
if not scalesets and not nodes:
delete_queue(self.get_pool_queue(), account_id=get_fuzz_storage())
delete_queue(self.get_pool_queue(), StorageType.corpus)
logging.info("pool stopped, deleting: %s", self.name)
self.state = PoolState.halt
self.delete()
@ -1053,16 +1055,16 @@ class ScalesetShrinkQueue:
return "to-shrink-%s" % self.scaleset_id.hex
def clear(self) -> None:
clear_queue(self.queue_name(), account_id=get_func_storage())
clear_queue(self.queue_name(), StorageType.config)
def create(self) -> None:
create_queue(self.queue_name(), account_id=get_func_storage())
create_queue(self.queue_name(), StorageType.config)
def delete(self) -> None:
delete_queue(self.queue_name(), account_id=get_func_storage())
delete_queue(self.queue_name(), StorageType.config)
def add_entry(self) -> None:
queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage())
queue_object(self.queue_name(), ShrinkEntry(), StorageType.config)
def should_shrink(self) -> bool:
return remove_first_message(self.queue_name(), account_id=get_func_storage())
return remove_first_message(self.queue_name(), StorageType.config)

View File

@ -21,8 +21,7 @@ from pydantic import Field
from .__version__ import __version__
from .azure.auth import build_auth
from .azure.containers import get_file_sas_url, save_blob
from .azure.creds import get_func_storage
from .azure.containers import StorageType, get_file_sas_url, save_blob
from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.vm import VM
@ -191,12 +190,12 @@ class Proxy(ORMMixin):
url=get_file_sas_url(
"proxy-configs",
"%s/config.json" % self.region,
account_id=get_func_storage(),
StorageType.config,
read=True,
),
notification=get_queue_sas(
"proxy",
account_id=get_func_storage(),
StorageType.config,
add=True,
),
forwards=forwards,
@ -207,7 +206,7 @@ class Proxy(ORMMixin):
"proxy-configs",
"%s/config.json" % self.region,
proxy_config.json(),
account_id=get_func_storage(),
StorageType.config,
)
@classmethod

View File

@ -10,7 +10,7 @@ from typing import Optional, Union
from onefuzztypes.models import Report
from pydantic import ValidationError
from .azure.containers import get_blob
from .azure.containers import StorageType, get_blob
def parse_report(
@ -50,7 +50,7 @@ def get_report(container: str, filename: str) -> Optional[Report]:
logging.error("get_report invalid extension: %s", metadata)
return None
blob = get_blob(container, filename)
blob = get_blob(container, filename, StorageType.corpus)
if blob is None:
logging.error("get_report invalid blob: %s", metadata)
return None

View File

@ -14,8 +14,8 @@ from onefuzztypes.models import Repro as BASE_REPRO
from onefuzztypes.models import ReproConfig, TaskVm
from .azure.auth import build_auth
from .azure.containers import save_blob
from .azure.creds import get_base_region, get_func_storage
from .azure.containers import StorageType, save_blob
from .azure.creds import get_base_region
from .azure.ip import get_public_ip
from .azure.vm import VM
from .extension import repro_extensions
@ -205,7 +205,7 @@ class Repro(BASE_REPRO, ORMMixin):
"repro-scripts",
"%s/%s" % (self.vm_id, filename),
files[filename],
account_id=get_func_storage(),
StorageType.config,
)
logging.info("saved repro script")

View File

@ -11,13 +11,13 @@ from uuid import UUID
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig
from ..azure.containers import blob_exists, container_exists, get_container_sas_url
from ..azure.creds import (
get_func_storage,
get_fuzz_storage,
get_instance_id,
get_instance_url,
from ..azure.containers import (
StorageType,
blob_exists,
container_exists,
get_container_sas_url,
)
from ..azure.creds import get_instance_id, get_instance_url
from ..azure.queue import get_queue_sas
from .defs import TASK_DEFINITIONS
@ -68,7 +68,7 @@ def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
containers: Dict[ContainerType, List[str]] = {}
for container in config.containers:
if container.name not in checked:
if not container_exists(container.name):
if not container_exists(container.name, StorageType.corpus):
raise TaskConfigError("missing container: %s" % container.name)
checked.add(container.name)
@ -137,7 +137,7 @@ def check_config(config: TaskConfig) -> None:
if TaskFeature.target_exe in definition.features:
container = [x for x in config.containers if x.type == ContainerType.setup][0]
if not blob_exists(container.name, config.task.target_exe):
if not blob_exists(container.name, config.task.target_exe, StorageType.corpus):
err = "target_exe `%s` does not exist in the setup container `%s`" % (
config.task.target_exe,
container.name,
@ -153,7 +153,7 @@ def check_config(config: TaskConfig) -> None:
for tool_path in tools_paths:
if config.task.generator_exe.startswith(tool_path):
generator = config.task.generator_exe.replace(tool_path, "")
if not blob_exists(container.name, generator):
if not blob_exists(container.name, generator, StorageType.corpus):
err = (
"generator_exe `%s` does not exist in the tools container `%s`"
% (
@ -188,7 +188,7 @@ def build_task_config(
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
heartbeat_queue=get_queue_sas(
"task-heartbeat",
account_id=get_func_storage(),
StorageType.config,
add=True,
),
back_channel_address="https://%s/api/back_channel" % (get_instance_url()),
@ -198,11 +198,11 @@ def build_task_config(
if definition.monitor_queue:
config.input_queue = get_queue_sas(
task_id,
StorageType.corpus,
add=True,
read=True,
update=True,
process=True,
account_id=get_fuzz_storage(),
)
for container_def in definition.containers:
@ -219,6 +219,7 @@ def build_task_config(
"path": "_".join(["task", container_def.type.name, str(i)]),
"url": get_container_sas_url(
container.name,
StorageType.corpus,
read=ContainerPermission.Read in container_def.permissions,
write=ContainerPermission.Write in container_def.permissions,
add=ContainerPermission.Add in container_def.permissions,

View File

@ -18,7 +18,7 @@ from onefuzztypes.webhooks import (
WebhookEventTaskStopped,
)
from ..azure.creds import get_fuzz_storage
from ..azure.containers import StorageType
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
@ -123,7 +123,7 @@ class Task(BASE_TASK, ORMMixin):
}
def init(self) -> None:
create_queue(self.task_id, account_id=get_fuzz_storage())
create_queue(self.task_id, StorageType.corpus)
self.state = TaskState.waiting
self.save()
@ -132,7 +132,7 @@ class Task(BASE_TASK, ORMMixin):
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
ProxyForward.remove_forward(self.task_id)
delete_queue(str(self.task_id), account_id=get_fuzz_storage())
delete_queue(str(self.task_id), StorageType.corpus)
Node.stop_task(self.task_id)
self.state = TaskState.stopped
self.save()

View File

@ -10,8 +10,12 @@ from uuid import UUID
from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from ..azure.containers import blob_exists, get_container_sas_url, save_blob
from ..azure.creds import get_func_storage
from ..azure.containers import (
StorageType,
blob_exists,
get_container_sas_url,
save_blob,
)
from ..pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task
@ -60,20 +64,26 @@ def schedule_tasks() -> None:
agent_config = build_task_config(task.job_id, task.task_id, task.config)
setup_container = get_setup_container(task.config)
setup_url = get_container_sas_url(setup_container, read=True, list=True)
setup_url = get_container_sas_url(
setup_container, StorageType.corpus, read=True, list=True
)
setup_script = None
if task.os == OS.windows and blob_exists(setup_container, "setup.ps1"):
if task.os == OS.windows and blob_exists(
setup_container, "setup.ps1", StorageType.corpus
):
setup_script = "setup.ps1"
if task.os == OS.linux and blob_exists(setup_container, "setup.sh"):
if task.os == OS.linux and blob_exists(
setup_container, "setup.sh", StorageType.corpus
):
setup_script = "setup.sh"
save_blob(
"task-configs",
"%s/config.json" % task.task_id,
agent_config.json(exclude_none=True),
account_id=get_func_storage(),
StorageType.config,
)
reboot = False
count = 1

View File

@ -10,7 +10,7 @@ from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType
from pydantic import BaseModel
from .azure.creds import get_func_storage
from .azure.containers import StorageType
from .azure.queue import queue_object
@ -46,7 +46,7 @@ def queue_update(
if not queue_object(
"update-queue",
update,
account_id=get_func_storage(),
StorageType.config,
visibility_timeout=visibility_timeout,
):
logging.error("unable to queue update")

View File

@ -27,7 +27,7 @@ from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel
from .__version__ import __version__
from .azure.creds import get_func_storage
from .azure.containers import StorageType
from .azure.queue import queue_object
from .orm import ORMMixin
@ -135,8 +135,8 @@ class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
queue_object(
"webhooks",
obj,
StorageType.config,
visibility_timeout=visibility_timeout,
account_id=get_func_storage(),
)

View File

@ -12,9 +12,9 @@ from onefuzztypes.models import AgentConfig, Error
from onefuzztypes.requests import PoolCreate, PoolSearch, PoolStop
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.azure.containers import StorageType
from ..onefuzzlib.azure.creds import (
get_base_region,
get_func_storage,
get_instance_id,
get_instance_url,
get_regions,
@ -33,7 +33,7 @@ def set_config(pool: Pool) -> Pool:
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
heartbeat_queue=get_queue_sas(
"node-heartbeat",
account_id=get_func_storage(),
StorageType.config,
add=True,
),
instance_id=get_instance_id(),