initial public release

This commit is contained in:
Brian Caswell
2020-09-18 12:21:04 -04:00
parent 9c3aa0bdfb
commit d3a0b292e6
387 changed files with 43810 additions and 28 deletions

View File

@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=W0612,C0111
# Placeholder version; presumably replaced at build/release time — TODO confirm.
__version__ = "0.0.0"

View File

@ -0,0 +1,75 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Callable, Union
from uuid import UUID
import azure.functions as func
import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from pydantic import BaseModel
from .pools import Scaleset
from .request import not_ok
class TokenData(BaseModel):
    """Identity claims extracted from a bearer token (see try_get_token_auth_header)."""

    # AAD application id (from the token's "appid" claim)
    application_id: UUID
    # AAD principal object id (from the token's "oid" claim)
    object_id: UUID
def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
    """Extract and decode the bearer token from the Authorization header."""

    def bad_request(message: str) -> Error:
        # every failure mode here is a malformed request
        return Error(code=ErrorCode.INVALID_REQUEST, errors=[message])

    auth: str = request.headers.get("Authorization", None)
    if not auth:
        return bad_request("Authorization header is expected")

    parts = auth.split()
    if parts[0].lower() != "bearer":
        return bad_request("Authorization header must start with Bearer")
    if len(parts) == 1:
        return bad_request("Token not found")
    if len(parts) > 2:
        return bad_request("Authorization header must be Bearer token")

    # This token has already been verified by the azure authentication layer
    token = jwt.decode(parts[1], verify=False)
    return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))
@cached(ttl=60)
def is_authorized(token_data: TokenData) -> bool:
    """True when the token's object id maps to at least one known scaleset."""
    # cached for a minute so every request does not hit storage
    matching = Scaleset.get_by_object_id(token_data.object_id)
    return len(matching) > 0
def verify_token(
    req: func.HttpRequest, func: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse:
    """Invoke `func` only when the request carries a recognized agent token.

    NOTE: the parameter name `func` shadows the `azure.functions` alias; it is
    kept for compatibility with existing keyword callers.
    """
    token = try_get_token_auth_header(req)
    if isinstance(token, Error):
        return not_ok(token, status_code=401, context="token verification")

    if is_authorized(token):
        return func(req)

    unauthorized = Error(code=ErrorCode.UNAUTHORIZED, errors=["Unrecognized agent"])
    return not_ok(unauthorized, status_code=401, context="token verification")

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import subprocess # nosec - used for ssh key generation
import tempfile
from typing import Tuple
from uuid import uuid4
from onefuzztypes.models import Authentication
def generate_keypair() -> Tuple[str, str]:
    """Generate a throwaway 2048-bit RSA keypair with ssh-keygen.

    Returns:
        (public_key, private_key) as text.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        key_path = os.path.join(tmpdir, "key")
        # -P "" -> no passphrase; the keys are consumed programmatically
        subprocess.check_output(  # nosec - all arguments are under our control
            ["ssh-keygen", "-t", "rsa", "-f", key_path, "-P", "", "-b", "2048"]
        )
        with open(key_path, "r") as handle:
            private = handle.read()
        with open(key_path + ".pub", "r") as handle:
            public = handle.read().strip()
        return (public, private)
def build_auth() -> Authentication:
    """Build fresh credentials (random password + SSH keypair) for a new VM."""
    public_key, private_key = generate_keypair()
    return Authentication(
        password=str(uuid4()),
        public_key=public_key,
        private_key=private_key,
    )

View File

@ -0,0 +1,155 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import os
import urllib.parse
from typing import Any, Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions
from .creds import get_blob_service
def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]:
    """Map container name -> metadata for every non-system container."""
    service = get_blob_service(account_id)
    containers: Dict[str, Dict[str, str]] = {}
    for container in service.list_containers(include_metadata=True):
        # names like "$logs" are Azure system containers; skip them
        if container.name.startswith("$"):
            continue
        containers[container.name] = container.metadata
    return containers
def get_container_metadata(
    name: str, account_id: Optional[str] = None
) -> Optional[Dict[str, str]]:
    """Fetch a container's metadata; None on error or when metadata is empty."""
    service = get_blob_service(account_id)
    try:
        metadata = service.get_container_metadata(name)
    except AzureHttpError:
        # azure storage already logs errors
        return None
    if not metadata:
        return None
    return cast(Dict[str, str], metadata)
def create_container(
    name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None
) -> Optional[str]:
    """Create a container and return a full-permission SAS URL for it.

    Returns None when container creation fails.
    """
    try:
        get_blob_service(account_id).create_container(name, metadata=metadata)
    except AzureHttpError:
        # azure storage already logs errors
        return None
    # fix: forward account_id — previously the SAS was always generated
    # against the default account even when the container was created in
    # an explicitly-specified one
    return get_container_sas_url(
        name,
        account_id=account_id,
        read=True,
        add=True,
        create=True,
        write=True,
        delete=True,
        list=True,
    )
def delete_container(name: str, account_id: Optional[str] = None) -> bool:
    """Delete a container; True on success, False when the request fails."""
    service = get_blob_service(account_id)
    try:
        result = service.delete_container(name)
    except AzureHttpError:
        # azure storage already logs errors
        return False
    return bool(result)
def get_container_sas_url(
    container: str,
    account_id: Optional[str] = None,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
) -> str:
    """Build a container SAS URL valid for 30 days with the given permissions.

    NOTE: `list` shadows the builtin, but it is part of the keyword interface.
    """
    service = get_blob_service(account_id)
    permission = ContainerPermissions(read, add, create, write, delete, list)
    expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
    sas_token = service.generate_container_shared_access_signature(
        container, permission=permission, expiry=expiry
    )
    raw_url = service.make_container_url(container, sas_token=sas_token)
    # strip the restype query arg that make_container_url inserts
    return str(raw_url.replace("?restype=container&", "?"))
def get_file_sas_url(
    container: str,
    name: str,
    account_id: Optional[str] = None,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
    days: int = 30,
    hours: int = 0,
    minutes: int = 0,
) -> str:
    """Build a SAS URL for one blob with the given permissions and lifetime."""
    service = get_blob_service(account_id)
    lifetime = datetime.timedelta(days=days, hours=hours, minutes=minutes)
    expiry = datetime.datetime.utcnow() + lifetime
    permission = BlobPermissions(read, add, create, write, delete, list)
    sas_token = service.generate_blob_shared_access_signature(
        container, name, permission=permission, expiry=expiry
    )
    return str(service.make_blob_url(container, name, sas_token=sas_token))
def save_blob(
    container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
) -> None:
    """Write `data` to container/name, creating the container if needed.

    Raises:
        TypeError: when data is neither str nor bytes.
    """
    service = get_blob_service(account_id)
    # create_container is a no-op when the container already exists
    service.create_container(container)
    if isinstance(data, str):
        service.create_blob_from_text(container, name, data)
    elif isinstance(data, bytes):
        service.create_blob_from_bytes(container, name, data)
    else:
        # fix: previously any other type silently skipped the write
        raise TypeError("data must be str or bytes, not %s" % type(data).__name__)
def get_blob(
    container: str, name: str, account_id: Optional[str] = None
) -> Optional[Any]:  # should be bytes
    """Read a blob's content as bytes; None when the blob does not exist."""
    service = get_blob_service(account_id)
    try:
        return service.get_blob_to_bytes(container, name).content
    except AzureMissingResourceHttpError:
        return None
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
    """True when the blob exists (probed via a cheap properties fetch)."""
    service = get_blob_service(account_id)
    try:
        service.get_blob_properties(container, name)
    except AzureMissingResourceHttpError:
        return False
    return True
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
    """Delete a blob; True on success, False when it does not exist."""
    service = get_blob_service(account_id)
    try:
        service.delete_blob(container, name)
    except AzureMissingResourceHttpError:
        return False
    return True
def auth_download_url(container: str, filename: str) -> str:
    """Build an instance-relative, auth-gated download URL for a blob."""
    query = urllib.parse.urlencode({"container": container, "filename": filename})
    return "%s/api/download?%s" % (os.environ["ONEFUZZ_INSTANCE"], query)

View File

@ -0,0 +1,116 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, List, Optional, Tuple
from azure.cli.core import CLIError
from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.subscription import SubscriptionClient
from azure.storage.blob import BlockBlobService
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id
from .monkeypatch import allow_more_workers
@cached(ttl=60)
def get_msi() -> MSIAuthentication:
    # Managed Service Identity credentials; cached briefly since construction
    # involves a token round-trip.
    return MSIAuthentication()
@cached(ttl=60)
def mgmt_client_factory(client_class: Any) -> Any:
    """Construct an Azure management client, preferring the CLI profile.

    Falls back to MSI credentials when no CLI profile is available;
    SubscriptionClient takes no subscription-id argument on that path.
    """
    allow_more_workers()
    try:
        return get_client_from_cli_profile(client_class)
    except CLIError:
        pass
    if issubclass(client_class, SubscriptionClient):
        return client_class(get_msi())
    return client_class(get_msi(), get_subscription())
@cached(ttl=60)
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]:
    """Resolve a storage account resource id to (account name, primary key).

    Defaults to the ONEFUZZ_DATA_STORAGE account when account_id is None.
    """
    if account_id is None:
        account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
    resource = parse_resource_id(account_id)
    client = mgmt_client_factory(StorageManagementClient)
    keys = client.storage_accounts.list_keys(
        resource["resource_group"], resource["name"]
    )
    return resource["name"], keys.keys[0].value
@cached(ttl=60)
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
    """Return a BlockBlobService for the given storage account (cached 60s)."""
    # fix: the old message said "getting blob container" (copy/paste) —
    # this function builds the blob *service* client
    logging.info("getting blob service (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    return BlockBlobService(account_name=name, account_key=key)
@cached
def get_base_resource_group() -> Any:  # should be str
    # Resource group name parsed from the ONEFUZZ_RESOURCE_GROUP resource id.
    return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]
@cached
def get_base_region() -> Any:  # should be str
    # Azure location of the instance's base resource group.
    client = mgmt_client_factory(ResourceManagementClient)
    group = client.resource_groups.get(get_base_resource_group())
    return group.location
@cached
def get_subscription() -> Any:  # should be str
    # Subscription id parsed out of the data-storage account's resource id.
    return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]
@cached
def get_fuzz_storage() -> str:
    # Resource id of the storage account holding fuzzing data.
    return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached
def get_func_storage() -> str:
    # Resource id of the storage account backing the Azure Functions app.
    return os.environ["ONEFUZZ_FUNC_STORAGE"]
@cached
def get_instance_name() -> str:
    # Name of this OneFuzz instance, from the deployment environment.
    return os.environ["ONEFUZZ_INSTANCE_NAME"]
@cached(ttl=60)
def get_regions() -> List[str]:
    """Sorted region names available to this subscription (cached 60s)."""
    client = mgmt_client_factory(SubscriptionClient)
    locations = client.subscriptions.list_locations(get_subscription())
    return sorted(x.name for x in locations)
def get_graph_client() -> Any:
    # Graph RBAC client used for AAD group-membership checks (see is_member_of).
    return mgmt_client_factory(GraphRbacManagementClient)
def is_member_of(group_id: str, member_id: str) -> bool:
    """Check AAD group membership via the Graph RBAC API."""
    params = CheckGroupMembershipParameters(group_id=group_id, member_id=member_id)
    result = get_graph_client().groups.is_member_of(params)
    return bool(result.value)

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Any
from azure.mgmt.compute import ComputeManagementClient
from msrestazure.azure_exceptions import CloudError
from .creds import mgmt_client_factory
def list_disks(resource_group: str) -> Any:
    """List every managed disk in the resource group."""
    logging.info("listing disks %s", resource_group)
    client = mgmt_client_factory(ComputeManagementClient)
    return client.disks.list_by_resource_group(resource_group)
def delete_disk(resource_group: str, name: str) -> bool:
    """Delete a managed disk; True on success, False when Azure rejects it."""
    logging.info("deleting disks %s : %s", resource_group, name)
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        client.disks.delete(resource_group, name)
    except CloudError as err:
        logging.error("unable to delete disk: %s", err)
        return False
    return True

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from azure.mgmt.compute import ComputeManagementClient
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import mgmt_client_factory
@cached(ttl=60)
def get_os(region: Region, image: str) -> Union[Error, OS]:
    """Determine the OS (via the OS enum) for an image.

    `image` is either a full resource id of a custom image, or a
    "publisher:offer:sku:version" marketplace reference. Cached for 60s.
    """
    client = mgmt_client_factory(ComputeManagementClient)
    # parse_resource_id yields no "resource_group" key for non-resource-id
    # strings, which routes marketplace references to the else branch
    parsed = parse_resource_id(image)
    if "resource_group" in parsed:
        # custom image: read the OS type off its os_disk profile
        try:
            name = client.images.get(
                parsed["resource_group"], parsed["name"]
            ).storage_profile.os_disk.os_type.name
        except CloudError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    else:
        # NOTE(review): raises ValueError if the reference does not have
        # exactly four ":"-separated fields — confirm callers validate first
        publisher, offer, sku, version = image.split(":")
        try:
            if version == "latest":
                # resolve "latest" to a concrete version before the get call
                version = client.virtual_machine_images.list(
                    region, publisher, offer, sku, top=1
                )[0].name
            name = client.virtual_machine_images.get(
                region, publisher, offer, sku, version
            ).os_disk_image.operating_system.name
        except CloudError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    return OS[name]

View File

@ -0,0 +1,150 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, Optional, Union
from uuid import UUID
from azure.mgmt.network import NetworkManagementClient
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from .creds import get_base_resource_group, mgmt_client_factory
from .subnet import create_virtual_network, get_subnet_id
from .vmss import get_instance_id
def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
    """Find the private IP of a scaleset node, or None when unavailable."""
    # map the machine id to the scaleset's per-instance id
    instance = get_instance_id(scaleset, machine_id)
    if not isinstance(instance, str):
        # non-str (presumably an Error from get_instance_id — confirm) means
        # the node cannot be resolved; report "no IP"
        return None
    resource_group = get_base_resource_group()
    client = mgmt_client_factory(NetworkManagementClient)
    intf = client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
        resource_group, str(scaleset)
    )
    try:
        for interface in intf:
            resource = parse_resource_id(interface.virtual_machine.id)
            if resource.get("resource_name") != instance:
                continue
            for config in interface.ip_configurations:
                if config.private_ip_address is None:
                    continue
                # first configured private address wins
                return str(config.private_ip_address)
    except CloudError:
        # this can fail if an interface is removed during the iteration
        pass
    return None
def get_ip(resource_group: str, name: str) -> Optional[Any]:
    """Fetch a public IP resource; None when it does not exist."""
    logging.info("getting ip %s:%s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        return client.public_ip_addresses.get(resource_group, name)
    except CloudError:
        return None
def delete_ip(resource_group: str, name: str) -> Any:
    """Start deletion of a public IP and return the Azure operation."""
    logging.info("deleting ip %s:%s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    return client.public_ip_addresses.delete(resource_group, name)
def create_ip(resource_group: str, name: str, location: str) -> Any:
    """Start creation of a dynamically-allocated public IP."""
    logging.info("creating ip for %s:%s in %s", resource_group, name, location)
    params: Dict[str, Union[str, Dict[str, str]]] = {
        "location": location,
        "public_ip_allocation_method": "Dynamic",
    }
    # optional owner tag, mirroring the other resource-creation helpers
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    client = mgmt_client_factory(NetworkManagementClient)
    return client.public_ip_addresses.create_or_update(resource_group, name, params)
def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
    """Fetch a network interface; None when it does not exist."""
    logging.info("getting nic: %s %s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        return client.network_interfaces.get(resource_group, name)
    except CloudError:
        return None
def delete_nic(resource_group: str, name: str) -> Optional[Any]:
    """Start deletion of a network interface and return the Azure operation."""
    logging.info("deleting nic %s:%s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    return client.network_interfaces.delete(resource_group, name)
def create_public_nic(resource_group: str, name: str, location: str) -> Optional[Error]:
    """Converge toward a NIC with a public IP; call repeatedly until done.

    Each call performs at most one provisioning step (vnet, then public IP,
    then NIC) and returns None while work is still pending; callers re-invoke
    until get_public_nic succeeds.
    """
    logging.info("creating nic for %s:%s in %s", resource_group, name, location)
    network_client = mgmt_client_factory(NetworkManagementClient)
    subnet_id = get_subnet_id(resource_group, location)
    if not subnet_id:
        # no subnet yet: start vnet creation (vnet/subnet named after location)
        return create_virtual_network(resource_group, location, location)
    ip = get_ip(resource_group, name)
    if not ip:
        # kick off IP allocation; NIC creation happens on a later call
        create_ip(resource_group, name, location)
        return None
    params = {
        "location": location,
        "ip_configurations": [
            {
                "name": "myIPConfig",
                "public_ip_address": ip,
                "subnet": {"id": subnet_id},
            }
        ],
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    try:
        network_client.network_interfaces.create_or_update(resource_group, name, params)
    except CloudError as err:
        # transient "RetryableError" responses are expected; retry later
        if "RetryableError" not in repr(err):
            return Error(
                code=ErrorCode.VM_CREATE_FAILED,
                errors=["unable to create nic: %s" % err],
            )
    return None
def get_public_ip(resource_id: str) -> Optional[str]:
    """Resolve a NIC resource id to its public IP address string, if any."""
    logging.info("getting ip for %s", resource_id)
    network_client = mgmt_client_factory(NetworkManagementClient)
    nic = parse_resource_id(resource_id)
    # NIC -> first ip configuration -> public ip reference
    public_ip_ref = (
        network_client.network_interfaces.get(nic["resource_group"], nic["name"])
        .ip_configurations[0]
        .public_ip_address
    )
    # public ip reference -> actual address
    ip_resource = parse_resource_id(public_ip_ref.id)
    address = network_client.public_ip_addresses.get(
        ip_resource["resource_group"], ip_resource["name"]
    ).ip_address
    return str(address) if address is not None else None

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import Any, Dict
from azure.mgmt.loganalytics import LogAnalyticsManagementClient
from memoization import cached
from .creds import get_base_resource_group, mgmt_client_factory
@cached(ttl=60)
def get_montior_client() -> Any:
    # Log Analytics management client.
    # NOTE: the "montior" typo is kept — this name is the public interface
    # and renaming it would break existing callers.
    return mgmt_client_factory(LogAnalyticsManagementClient)
@cached(ttl=60)
def get_monitor_settings() -> Dict[str, str]:
    """Fetch the Log Analytics workspace id and shared key (cached 60s)."""
    workspace_name = os.environ["ONEFUZZ_MONITOR"]
    resource_group = get_base_resource_group()
    client = get_montior_client()
    customer_id = client.workspaces.get(resource_group, workspace_name).customer_id
    keys = client.shared_keys.get_shared_keys(resource_group, workspace_name)
    return {"id": customer_id, "key": keys.primary_shared_key}

View File

@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import logging
# Process-wide flag so the thread-pool bump below happens at most once.
WORKERS_DONE = False
def allow_more_workers() -> None:
    """Raise the Azure Functions sync-call thread pool limit from 1 to 50.

    HACK: walks the interpreter stack to find the functions-worker dispatcher
    frame and mutates its private thread pool in place — there is no public
    setting for this, hence the monkeypatch.
    """
    global WORKERS_DONE
    if WORKERS_DONE:
        return
    stack = inspect.stack()
    for entry in stack:
        if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
            # only bump when still at the default of a single worker
            if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
                logging.info("bumped thread worker count to 50")
                entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50
    WORKERS_DONE = True

View File

@ -0,0 +1,46 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Optional, Union
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import get_base_resource_group
from .subnet import create_virtual_network, delete_subnet, get_subnet_id
class Network:
    """Per-region virtual network (named after the region) in the base group."""

    def __init__(self, region: Region):
        self.group = get_base_resource_group()
        self.region = region

    def exists(self) -> bool:
        """True when the region's subnet already exists."""
        return self.get_id() is not None

    def get_id(self) -> Optional[str]:
        """Subnet id for this region, or None when not yet created."""
        return get_subnet_id(self.group, self.region)

    def create(self) -> Union[None, Error]:
        """Create the virtual network if missing; None on success or no-op."""
        if not self.exists():
            result = create_virtual_network(self.group, self.region, self.region)
            # fix: create_virtual_network returns Optional[Error], never a
            # CloudError, so the old isinstance(result, CloudError) check was
            # always False and creation failures were silently swallowed
            if isinstance(result, Error):
                logging.error(
                    "network creation failed: %s- %s",
                    self.region,
                    result,
                )
                return result
        return None

    def delete(self) -> None:
        """Tear down the region's network (delete_subnet removes the vnet)."""
        delete_subnet(self.group, self.region)

View File

@ -0,0 +1,153 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import base64
import datetime
import json
import logging
from typing import List, Optional, Type, TypeVar, Union
from uuid import UUID
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.queue import (
QueueSasPermissions,
QueueServiceClient,
generate_queue_sas,
)
from memoization import cached
from pydantic import BaseModel
from .creds import get_storage_account_name_key
QueueNameType = Union[str, UUID]
@cached(ttl=60)
def get_queue_client(account_id: str) -> QueueServiceClient:
    """Return a QueueServiceClient for the storage account (cached 60s)."""
    # fix: the old message said "getting blob container" (copy/paste from
    # the blob helper) — this builds the queue service client
    logging.info("getting queue client (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    account_url = "https://%s.queue.core.windows.net" % name
    return QueueServiceClient(
        account_url=account_url,
        credential={"account_name": name, "account_key": key},
    )
def get_queue_sas(
    queue: QueueNameType,
    *,
    account_id: str,
    read: bool = False,
    add: bool = False,
    update: bool = False,
    process: bool = False,
) -> str:
    """Build a queue SAS URL valid for 30 days with the given permissions."""
    logging.info("getting queue sas %s (account_id: %s)", queue, account_id)
    name, key = get_storage_account_name_key(account_id)
    permission = QueueSasPermissions(read=read, add=add, update=update, process=process)
    token = generate_queue_sas(
        name,
        str(queue),
        key,
        permission=permission,
        expiry=datetime.datetime.utcnow() + datetime.timedelta(days=30),
    )
    return "https://%s.queue.core.windows.net/%s?%s" % (name, queue, token)
@cached(ttl=60)
def create_queue(name: QueueNameType, *, account_id: str) -> None:
    """Create the queue if it does not already exist.

    Cached so repeated callers skip the storage round-trip for 60 seconds.
    """
    client = get_queue_client(account_id)
    try:
        client.create_queue(str(name))
    except ResourceExistsError:
        # already created; treat as success
        pass
def delete_queue(name: QueueNameType, *, account_id: str) -> None:
    """Delete the named queue if it exists."""
    client = get_queue_client(account_id)
    existing = [x["name"] for x in client.list_queues()]
    # fix: pass str(name) — `name` may be a UUID, and every other call in
    # this module stringifies before handing it to the SDK
    if str(name) in existing:
        client.delete_queue(str(name))
def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]:
    """Return a client for the named queue, or None when it does not exist."""
    service = get_queue_client(account_id)
    try:
        return service.get_queue_client(str(name))
    except ResourceNotFoundError:
        return None
def send_message(
    name: QueueNameType,
    message: bytes,
    *,
    account_id: str,
) -> None:
    """Best-effort enqueue of base64-encoded bytes; no-op if queue is gone."""
    queue = get_queue(name, account_id=account_id)
    if not queue:
        return
    encoded = base64.b64encode(message).decode()
    try:
        queue.send_message(encoded)
    except ResourceNotFoundError:
        # queue vanished between lookup and send; best-effort semantics
        pass
A = TypeVar("A", bound=BaseModel)
MIN_PEEK_SIZE = 1
MAX_PEEK_SIZE = 32
# Peek at a max of 32 messages
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue(
    name: QueueNameType,
    *,
    account_id: str,
    object_type: Type[A],
    max_messages: int = MAX_PEEK_SIZE,
) -> List[A]:
    """Peek up to max_messages entries, parsed as object_type models."""
    if not MIN_PEEK_SIZE <= max_messages <= MAX_PEEK_SIZE:
        raise ValueError("invalid max messages: %s" % max_messages)
    queue = get_queue(name, account_id=account_id)
    if not queue:
        return []
    # messages are stored as base64-encoded JSON (see queue_object)
    return [
        object_type.parse_obj(json.loads(base64.b64decode(message.content)))
        for message in queue.peek_messages(max_messages=max_messages)
    ]
def queue_object(
    name: QueueNameType,
    message: BaseModel,
    *,
    account_id: str,
    visibility_timeout: Optional[int] = None,
) -> bool:
    """Serialize a pydantic model onto the queue; True on success.

    Raises when the queue itself does not exist.
    """
    queue = get_queue(name, account_id=account_id)
    if not queue:
        raise Exception("unable to queue object, no such queue: %s" % queue)
    payload = base64.b64encode(message.json(exclude_none=True).encode()).decode()
    try:
        queue.send_message(payload, visibility_timeout=visibility_timeout)
    except ResourceNotFoundError:
        return False
    return True

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Optional, Union, cast
from azure.mgmt.network import NetworkManagementClient
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from .creds import mgmt_client_factory
def get_subnet_id(resource_group: str, name: str) -> Optional[str]:
    """Subnet id for the vnet/subnet both named `name`, or None if missing."""
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        # vnet and subnet share the same name (see create_virtual_network)
        subnet = network_client.subnets.get(resource_group, name, name)
    except CloudError:
        logging.info(
            "subnet missing: resource group: %s name: %s",
            resource_group,
            name,
        )
        return None
    return cast(str, subnet.id)
def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
    """Delete the virtual network; None when still in use, else the operation."""
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        return network_client.virtual_networks.delete(resource_group, name)
    except CloudError as err:
        # a subnet with attached resources cannot be deleted yet; report and
        # let the caller retry later
        if err.error and "InUseSubnetCannotBeDeleted" in str(err.error):
            logging.error(
                "subnet delete failed: %s %s : %s", resource_group, name, repr(err)
            )
            return None
        raise err
def create_virtual_network(
    resource_group: str, name: str, location: str
) -> Optional[Error]:
    """Create a vnet (10.0.0.0/8) with one subnet (10.0.0.0/16) named `name`."""
    logging.info(
        "creating subnet - resource group: %s name: %s location: %s",
        resource_group,
        name,
        location,
    )
    params = {
        "location": location,
        "address_space": {"address_prefixes": ["10.0.0.0/8"]},
        "subnets": [{"name": name, "address_prefix": "10.0.0.0/16"}],
    }
    # optional owner tag, mirroring the other resource-creation helpers
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        network_client.virtual_networks.create_or_update(resource_group, name, params)
    except CloudError as err:
        return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err.message)])
    return None

View File

@ -0,0 +1,30 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Optional
from azure.cosmosdb.table import TableService
from memoization import cached
from .creds import get_storage_account_name_key
@cached(ttl=60)
def get_client(
    table: Optional[str] = None, account_id: Optional[str] = None
) -> TableService:
    """TableService for the account, creating `table` when given and missing.

    Defaults to the ONEFUZZ_FUNC_STORAGE account. Cached for 60 seconds.
    """
    if account_id is None:
        account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
    logging.info("getting table account: (account_id: %s)", account_id)
    account_name, account_key = get_storage_account_name_key(account_id)
    service = TableService(account_name=account_name, account_key=account_key)
    if table and not service.exists(table):
        logging.info("creating missing table %s", table)
        service.create_table(table, fail_on_exist=False)
    return service

View File

@ -0,0 +1,273 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import VirtualMachine
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Authentication, Error
from onefuzztypes.primitives import Extension, Region
from pydantic import BaseModel
from .creds import get_base_resource_group, mgmt_client_factory
from .disk import delete_disk, list_disks
from .image import get_os
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
def get_vm(name: str) -> Optional[VirtualMachine]:
    """Fetch a VM (with instance view) from the base resource group."""
    resource_group = get_base_resource_group()
    # fix: the format string had three placeholders but only two arguments,
    # which makes the logging module raise (and swallow) a formatting error
    logging.debug("getting vm: %s %s", resource_group, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return cast(
            VirtualMachine,
            compute_client.virtual_machines.get(
                resource_group, name, expand="instanceView"
            ),
        )
    except CloudError as err:
        logging.debug("vm does not exist %s", err)
        return None
def create_vm(
    name: str,
    location: str,
    vm_sku: str,
    image: str,
    password: str,
    ssh_public_key: str,
) -> Union[None, Error]:
    """Create (or continue creating) a single VM; call repeatedly until done.

    Returns None both on success and while waiting on NIC provisioning, so
    callers re-invoke until the VM exists; returns Error on hard failure.
    """
    resource_group = get_base_resource_group()
    logging.info("creating vm %s:%s:%s", resource_group, location, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    nic = get_public_nic(resource_group, name)
    if nic is None:
        # kick off NIC creation and come back on the next invocation
        result = create_public_nic(resource_group, name, location)
        if isinstance(result, Error):
            return result
        logging.info("waiting on nic creation")
        return None
    if image.startswith("/"):
        # resource id of a custom image
        image_ref = {"id": image}
    else:
        # marketplace reference: publisher:offer:sku:version
        image_val = image.split(":", 4)
        image_ref = {
            "publisher": image_val[0],
            "offer": image_val[1],
            "sku": image_val[2],
            "version": image_val[3],
        }
    params: Dict = {
        "location": location,
        "os_profile": {
            "computer_name": "node",
            "admin_username": "onefuzz",
            "admin_password": password,
        },
        "hardware_profile": {"vm_size": vm_sku},
        "storage_profile": {"image_reference": image_ref},
        "network_profile": {"network_interfaces": [{"id": nic.id}]},
    }
    image_os = get_os(location, image)
    if isinstance(image_os, Error):
        return image_os
    if image_os == OS.linux:
        # linux nodes: key-based SSH only, password auth disabled
        params["os_profile"]["linux_configuration"] = {
            "disable_password_authentication": True,
            "ssh": {
                "public_keys": [
                    {
                        "path": "/home/onefuzz/.ssh/authorized_keys",
                        "key_data": ssh_public_key,
                    }
                ]
            },
        }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    try:
        compute_client.virtual_machines.create_or_update(resource_group, name, params)
    except CloudError as err:
        if "The request failed due to conflict with a concurrent request" in str(err):
            # another invocation is already creating this VM; not an error
            logging.debug(
                "create VM had conflicts with concurrent request, ignoring %s", err
            )
            return None
        return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
    return None
def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
    """Fetch a VM extension; None when it does not exist."""
    resource_group = get_base_resource_group()
    # fix: the format string had four placeholders but only three arguments,
    # which makes the logging module raise (and swallow) a formatting error
    logging.debug(
        "getting extension: %s:%s:%s",
        resource_group,
        vm_name,
        extension_name,
    )
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return compute_client.virtual_machine_extensions.get(
            resource_group, vm_name, extension_name
        )
    except CloudError as err:
        logging.error("extension does not exist %s", err)
        return None
def create_extension(vm_name: str, extension: Dict) -> Any:
    """Start creation (or update) of a VM extension; returns the operation."""
    resource_group = get_base_resource_group()
    logging.info(
        "creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
    )
    client = mgmt_client_factory(ComputeManagementClient)
    return client.virtual_machine_extensions.create_or_update(
        resource_group, vm_name, extension["name"], extension
    )
def delete_vm(name: str) -> Any:
    """Start deletion of the VM in the base resource group."""
    resource_group = get_base_resource_group()
    logging.info("deleting vm: %s %s", resource_group, name)
    client = mgmt_client_factory(ComputeManagementClient)
    return client.virtual_machines.delete(resource_group, name)
def has_components(name: str) -> bool:
    """True while any resource associated with the VM still exists.

    Azure VM deletion requires deleting the VM first and then each of its
    resources; this gates marking the teardown "done".
    """
    resource_group = get_base_resource_group()
    if get_vm(name):
        return True
    if get_public_nic(resource_group, name):
        return True
    if get_ip(resource_group, name):
        return True
    # disks belonging to the VM share its name as a prefix
    return any(x.name.startswith(name) for x in list_disks(resource_group))
def delete_vm_components(name: str) -> bool:
    """Delete one class of the VM's resources per call; True when all gone.

    Order matters: VM first, then NIC, then IP, then disks. Each call handles
    the first class still present and returns False so the caller retries.
    """
    resource_group = get_base_resource_group()
    logging.info("deleting vm components %s:%s", resource_group, name)
    if get_vm(name):
        logging.info("deleting vm %s:%s", resource_group, name)
        delete_vm(name)
        return False
    if get_public_nic(resource_group, name):
        logging.info("deleting nic %s:%s", resource_group, name)
        delete_nic(resource_group, name)
        return False
    if get_ip(resource_group, name):
        logging.info("deleting ip %s:%s", resource_group, name)
        delete_ip(resource_group, name)
        return False
    # disks belonging to the VM share its name as a prefix
    disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
    if disks:
        for disk in disks:
            logging.info("deleting disk %s:%s", resource_group, disk)
            delete_disk(resource_group, disk)
        return False
    return True
class VM(BaseModel):
    """A single managed VM plus the state needed to drive its lifecycle."""

    # VM name (UUID for generated VMs, plain str otherwise)
    name: Union[UUID, str]
    region: Region
    # Azure VM size
    sku: str
    # custom-image resource id or publisher:offer:sku:version reference
    image: str
    # password + SSH keypair used to access the node
    auth: Authentication
    def is_deleted(self) -> bool:
        # NOTE(review): despite the name, this returns has_components(),
        # i.e. True while resources STILL exist — confirm callers expect
        # this polarity before relying on it
        return has_components(str(self.name))
    def exists(self) -> bool:
        """True when the Azure VM resource currently exists."""
        return self.get() is not None
    def get(self) -> Optional[VirtualMachine]:
        """Fetch the underlying Azure VM, or None when not found."""
        return get_vm(str(self.name))
    def create(self) -> Union[None, Error]:
        """Start VM creation; no-op when the VM already exists."""
        if self.get() is not None:
            return None
        logging.info("vm creating: %s", self.name)
        return create_vm(
            str(self.name),
            self.region,
            self.sku,
            self.image,
            self.auth.password,
            self.auth.public_key,
        )
    def delete(self) -> bool:
        """Delete one class of VM resources per call; True when all gone."""
        return delete_vm_components(str(self.name))
    def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
        """Ensure the given extensions exist; True once all have succeeded.

        Returns False while any extension is still provisioning, and Error
        when one reports a Failed state.
        """
        status = []
        to_create = []
        for config in extensions:
            if not isinstance(config["name"], str):
                logging.error("vm agent - incompatable name: %s", repr(config))
                continue
            extension = get_extension(str(self.name), config["name"])
            if extension:
                logging.info(
                    "vm extension state: %s - %s - %s",
                    self.name,
                    config["name"],
                    extension.provisioning_state,
                )
                status.append(extension.provisioning_state)
            else:
                to_create.append(config)
        if to_create:
            # kick off creation of missing extensions; report not-done below
            for config in to_create:
                create_extension(str(self.name), config)
        else:
            if all([x == "Succeeded" for x in status]):
                return True
            elif "Failed" in status:
                return Error(
                    code=ErrorCode.VM_CREATE_FAILED,
                    errors=["failed to launch extension"],
                )
            elif not ("Creating" in status or "Updating" in status):
                logging.error("vm agent - unknown state %s: %s", self.name, status)
        return False

View File

@ -0,0 +1,334 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import ResourceSku, ResourceSkuRestrictionsType
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import get_base_resource_group, mgmt_client_factory
from .image import get_os
def list_vmss(name: UUID) -> Optional[List[str]]:
    """Return the instance IDs of a scaleset's VMs, or None on cloud error."""
    resource_group = get_base_resource_group()
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        nodes = client.virtual_machine_scale_set_vms.list(resource_group, str(name))
        return [node.instance_id for node in nodes]
    except CloudError as err:
        logging.error("cloud error listing vmss: %s (%s)", name, err)
        return None
def delete_vmss(name: UUID) -> Any:
    """Delete a scaleset; CloudErrors are logged rather than raised."""
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        client.virtual_machine_scale_sets.delete(
            get_base_resource_group(), str(name)
        )
    except CloudError as err:
        logging.error("cloud error deleting vmss: %s (%s)", name, err)
def get_vmss(name: UUID) -> Optional[Any]:
    """Fetch a scaleset by name, or None if it does not exist."""
    logging.debug("getting vm: %s", name)
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        return client.virtual_machine_scale_sets.get(
            get_base_resource_group(), str(name)
        )
    except CloudError as err:
        logging.debug("vm does not exist %s", err)
        return None
def resize_vmss(name: UUID, capacity: int) -> None:
    """Set the scaleset's VM capacity; raises UnableToUpdate if not modifiable."""
    check_can_update(name)

    logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
    client = mgmt_client_factory(ComputeManagementClient)
    client.virtual_machine_scale_sets.update(
        get_base_resource_group(), str(name), {"sku": {"capacity": capacity}}
    )
def get_vmss_size(name: UUID) -> Optional[int]:
    """Return the scaleset's configured capacity, or None if it doesn't exist."""
    vmss = get_vmss(name)
    return None if vmss is None else cast(int, vmss.sku.capacity)
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
    """Map each scaleset node's vm_id (UUID) to its Azure instance ID.

    Returns an empty mapping if the scaleset is unavailable.
    """
    logging.info("get instance IDs for scaleset: %s", name)
    group = get_base_resource_group()
    client = mgmt_client_factory(ComputeManagementClient)

    mapping: Dict[UUID, str] = {}
    try:
        for node in client.virtual_machine_scale_set_vms.list(group, str(name)):
            mapping[UUID(node.vm_id)] = cast(str, node.instance_id)
    except CloudError:
        logging.debug("scaleset not available: %s", name)
    return mapping
@cached(ttl=60)
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
    """Look up a single node's instance ID by vm_id (cached for 60s)."""
    logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
    client = mgmt_client_factory(ComputeManagementClient)
    group = get_base_resource_group()

    needle = str(vm_id)
    for node in client.virtual_machine_scale_set_vms.list(group, str(name)):
        if node.vm_id == needle:
            return cast(str, node.instance_id)

    return Error(
        code=ErrorCode.UNABLE_TO_FIND,
        errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
    )
class UnableToUpdate(Exception):
    """Raised when a scaleset is missing or not in an updatable state."""

    pass
def check_can_update(name: UUID) -> Any:
    """Return the scaleset if it exists and has finished provisioning.

    Raises UnableToUpdate otherwise.
    """
    vmss = get_vmss(name)
    if vmss is None or vmss.provisioning_state != "Succeeded":
        raise UnableToUpdate
    return vmss
def reimage_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
    """Reimage the given scaleset nodes; unknown vm_ids are logged and skipped."""
    check_can_update(name)

    logging.info("reimaging scaleset VM - name: %s vm_ids:%s", name, vm_ids)
    machine_to_id = list_instance_ids(name)
    to_reimage = []
    for vm_id in vm_ids:
        instance = machine_to_id.get(vm_id)
        if instance is None:
            logging.info("unable to find vm_id for %s:%s", name, vm_id)
        else:
            to_reimage.append(instance)

    if to_reimage:
        client = mgmt_client_factory(ComputeManagementClient)
        client.virtual_machine_scale_sets.reimage_all(
            get_base_resource_group(), str(name), instance_ids=to_reimage
        )
    return None
def delete_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
    """Delete the given scaleset nodes; unknown vm_ids are logged and skipped."""
    check_can_update(name)

    logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
    machine_to_id = list_instance_ids(name)
    to_delete = []
    for vm_id in vm_ids:
        instance = machine_to_id.get(vm_id)
        if instance is None:
            logging.info("unable to find vm_id for %s:%s", name, vm_id)
        else:
            to_delete.append(instance)

    if to_delete:
        client = mgmt_client_factory(ComputeManagementClient)
        client.virtual_machine_scale_sets.delete_instances(
            get_base_resource_group(), str(name), instance_ids=to_delete
        )
    return None
def update_extensions(name: UUID, extensions: List[Any]) -> None:
    """Replace the scaleset's VM extension profile."""
    check_can_update(name)

    logging.info("updating VM extensions: %s", name)
    client = mgmt_client_factory(ComputeManagementClient)
    profile = {"extension_profile": {"extensions": extensions}}
    client.virtual_machine_scale_sets.update(
        get_base_resource_group(), str(name), {"virtual_machine_profile": profile}
    )
def create_vmss(
    location: Region,
    name: UUID,
    vm_sku: str,
    vm_count: int,
    image: str,
    network_id: str,
    spot_instances: bool,
    extensions: List[Any],
    password: str,
    ssh_public_key: str,
    tags: Dict[str, str],
) -> Optional[Error]:
    """Create a VM scaleset if it does not already exist (idempotent).

    Returns None on success or if the scaleset already exists, otherwise an
    Error describing why creation failed.
    """
    vmss = get_vmss(name)
    if vmss is not None:
        return None

    # fixed: the original format string had no separator between
    # "creating VM count" and "name:", producing "...countname:..." in logs
    logging.info(
        "creating VM count - "
        "name: %s vm_sku: %s vm_count: %d "
        "image: %s subnet: %s spot_instances: %s",
        name,
        vm_sku,
        vm_count,
        image,
        network_id,
        spot_instances,
    )

    resource_group = get_base_resource_group()
    compute_client = mgmt_client_factory(ComputeManagementClient)

    # `image` is either a full ARM resource ID (starts with "/") or a
    # marketplace reference "publisher:offer:sku:version"
    if image.startswith("/"):
        image_ref = {"id": image}
    else:
        image_val = image.split(":", 4)
        image_ref = {
            "publisher": image_val[0],
            "offer": image_val[1],
            "sku": image_val[2],
            "version": image_val[3],
        }

    sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}

    params: Dict[str, Any] = {
        "location": location,
        "do_not_run_extensions_on_overprovisioned_vms": True,
        "upgrade_policy": {"mode": "Manual"},
        "sku": sku,
        "identity": {"type": "SystemAssigned"},
        "virtual_machine_profile": {
            "priority": "Regular",
            "storage_profile": {"image_reference": image_ref},
            "os_profile": {
                "computer_name_prefix": "node",
                "admin_username": "onefuzz",
                "admin_password": password,
            },
            "network_profile": {
                "network_interface_configurations": [
                    {
                        "name": "onefuzz-nic",
                        "primary": True,
                        "ip_configurations": [
                            {"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
                        ],
                    }
                ]
            },
            "extension_profile": {"extensions": extensions},
        },
    }

    image_os = get_os(location, image)
    if isinstance(image_os, Error):
        return image_os

    if image_os == OS.linux:
        # linux nodes use SSH keys only; password auth is disabled
        params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
            "disable_password_authentication": True,
            "ssh": {
                "public_keys": [
                    {
                        "path": "/home/onefuzz/.ssh/authorized_keys",
                        "key_data": ssh_public_key,
                    }
                ]
            },
        }

    if spot_instances:
        # Setting max price to -1 means it won't be evicted because of
        # price.
        #
        # https://docs.microsoft.com/en-us/azure/
        # virtual-machine-scale-sets/use-spot#resource-manager-templates
        params["virtual_machine_profile"].update(
            {
                "eviction_policy": "Delete",
                "priority": "Spot",
                "billing_profile": {"max_price": -1},
            }
        )

    params["tags"] = tags.copy()
    owner = os.environ.get("ONEFUZZ_OWNER")
    if owner:
        params["tags"]["OWNER"] = owner

    try:
        # fixed: pass str(name) for consistency with every other API call
        # in this module (name is a UUID)
        compute_client.virtual_machine_scale_sets.create_or_update(
            resource_group, str(name), params
        )
    except CloudError as err:
        if "The request failed due to conflict with a concurrent request" in repr(err):
            logging.debug(
                "create VM had conflicts with concurrent request, ignoring %s", err
            )
            return None
        return Error(
            code=ErrorCode.VM_CREATE_FAILED,
            errors=["creating vmss: %s" % err],
        )
    return None
@cached(ttl=60)
def list_available_skus(location: str) -> List[str]:
    """List VM SKU names usable in `location` (cached for 60s).

    SKUs carrying a location restriction for this region are excluded.
    """
    compute_client = mgmt_client_factory(ComputeManagementClient)
    sku_names: List[str] = []
    for sku in compute_client.resource_skus.list(
        filter="location eq '%s'" % location
    ):
        restricted = sku.restrictions is not None and any(
            restriction.type == ResourceSkuRestrictionsType.location
            and location.upper() in [v.upper() for v in restriction.values]
            for restriction in sku.restrictions
        )
        if not restricted:
            sku_names.append(sku.name)
    return sku_names

View File

@ -0,0 +1,54 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from enum import Enum
from queue import Empty, Queue
from typing import Dict, Optional, Union
from uuid import UUID
from onefuzztypes.primitives import Event
EVENTS: Queue = Queue()
def resolve(data: Event) -> Union[str, int, Dict[str, str]]:
    """Convert an event value into a JSON-serializable primitive.

    Strings and ints pass through; UUIDs are stringified; Enums become
    their name; dict values are stringified.  Raises NotImplementedError
    for any other type.
    """
    if isinstance(data, str):
        return data
    if isinstance(data, UUID):
        return str(data)
    if isinstance(data, Enum):
        return data.name
    if isinstance(data, int):
        return data
    if isinstance(data, dict):
        # build a new dict instead of mutating the caller's argument
        return {key: str(value) for key, value in data.items()}
    raise NotImplementedError("no conversion from %s" % type(data))
def get_event() -> Optional[str]:
    """Drain up to 10 queued events and return them as a SignalR-style
    JSON message, or None if the queue is empty."""
    batch = []
    while len(batch) < 10:
        try:
            (event_type, data) = EVENTS.get(block=False)
        except Empty:
            break
        batch.append({"type": event_type, "data": data})
        EVENTS.task_done()

    if not batch:
        return None
    return json.dumps({"target": "dashboard", "arguments": batch})
def add_event(message_type: str, data: Dict[str, Event]) -> None:
    """Normalize `data` values (in place) and enqueue a dashboard event."""
    # note: the caller's dict is mutated; values are converted via resolve()
    for key in data:
        data[key] = resolve(data[key])
    EVENTS.put((message_type, data))

View File

@ -0,0 +1,328 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import List, Optional
from uuid import UUID
from onefuzztypes.enums import OS, AgentMode
from onefuzztypes.models import AgentConfig, ReproConfig
from onefuzztypes.primitives import Extension, Region
from .azure.containers import get_container_sas_url, get_file_sas_url, save_blob
from .azure.creds import get_func_storage, get_instance_name
from .azure.monitor import get_monitor_settings
from .reports import get_report
# TODO: figure out how to create VM specific SSH keys for Windows.
#
# Previously done via task specific scripts:
# if is_windows and auth is not None:
# ssh_key = auth.public_key.strip()
# ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
# commands += ['Set-Content -Path %s -Value "%s"' % (ssh_path, ssh_key)]
# return commands
def generic_extensions(region: Region, os: OS) -> List[Extension]:
    """Extensions common to every VM: monitoring, plus the dependency
    agent where available."""
    extensions = [monitor_extension(region, os)]
    dependency = dependency_extension(region, os)
    if dependency:
        extensions.append(dependency)
    return extensions
def monitor_extension(region: Region, os: OS) -> Extension:
    """Build the OMS monitoring agent extension config for the given OS.

    Raises NotImplementedError for unsupported OSes.
    """
    settings = get_monitor_settings()

    if os == OS.windows:
        agent_type, handler_version = "MicrosoftMonitoringAgent", "1.0"
    elif os == OS.linux:
        agent_type, handler_version = "OmsAgentForLinux", "1.12"
    else:
        raise NotImplementedError("unsupported os: %s" % os)

    return {
        "name": "OMSExtension",
        "publisher": "Microsoft.EnterpriseCloud.Monitoring",
        "type": agent_type,
        "typeHandlerVersion": handler_version,
        "location": region,
        "autoUpgradeMinorVersion": True,
        "settings": {"workspaceId": settings["id"]},
        "protectedSettings": {"workspaceKey": settings["key"]},
    }
def dependency_extension(region: Region, os: OS) -> Optional[Extension]:
    """Dependency-agent extension for Windows; None for other OSes."""
    if os != OS.windows:
        # TODO: dependency agent for linux is not reliable
        # extension = {
        #     "name": "DependencyAgentLinux",
        #     "publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
        #     "type": "DependencyAgentLinux",
        #     "typeHandlerVersion": "9.5",
        #     "location": vm.region,
        #     "autoUpgradeMinorVersion": True,
        # }
        return None

    return {
        "name": "DependencyAgentWindows",
        "publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
        "type": "DependencyAgentWindows",
        "typeHandlerVersion": "9.5",
        "location": region,
        "autoUpgradeMinorVersion": True,
    }
def build_pool_config(pool_name: str) -> str:
    """Write the pool's agent config blob and return a read-only SAS URL to it."""
    pool_config = AgentConfig(
        pool_name=pool_name,
        onefuzz_url="https://%s.azurewebsites.net" % get_instance_name(),
        instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
        telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
    )

    blob_name = "%s/config.json" % pool_name
    save_blob(
        "vm-scripts",
        blob_name,
        pool_config.json(),
        account_id=get_func_storage(),
    )
    return get_file_sas_url(
        "vm-scripts",
        blob_name,
        account_id=get_func_storage(),
        read=True,
    )
def update_managed_scripts(mode: AgentMode) -> None:
    """Regenerate the managed setup scripts (managed.ps1 / managed.sh) that
    sync the instance-specific-setup and tools containers onto a node.

    NOTE(review): the `mode` parameter is currently unused -- confirm
    whether mode-specific script content was intended.
    """
    commands = [
        "azcopy sync '%s' instance-specific-setup"
        % (
            get_container_sas_url(
                "instance-specific-setup",
                read=True,
                list=True,
                account_id=get_func_storage(),
            )
        ),
        "azcopy sync '%s' tools"
        % (
            get_container_sas_url(
                "tools", read=True, list=True, account_id=get_func_storage()
            )
        ),
    ]

    # same command list, with OS-appropriate line endings
    save_blob(
        "vm-scripts",
        "managed.ps1",
        "\r\n".join(commands) + "\r\n",
        account_id=get_func_storage(),
    )
    save_blob(
        "vm-scripts",
        "managed.sh",
        "\n".join(commands) + "\n",
        account_id=get_func_storage(),
    )
def agent_config(
    region: Region, os: OS, mode: AgentMode, *, urls: Optional[List[str]] = None
) -> Extension:
    """Build the custom-script extension that downloads the agent tooling
    (plus any extra `urls`) and launches the OS-appropriate setup script
    in the given mode.

    Raises NotImplementedError for unsupported OSes.
    """
    update_managed_scripts(mode)

    if urls is None:
        urls = []

    if os == OS.windows:
        urls += [
            get_file_sas_url(
                "vm-scripts",
                "managed.ps1",
                account_id=get_func_storage(),
                read=True,
            ),
            get_file_sas_url(
                "tools",
                "win64/azcopy.exe",
                account_id=get_func_storage(),
                read=True,
            ),
            get_file_sas_url(
                "tools",
                "win64/setup.ps1",
                account_id=get_func_storage(),
                read=True,
            ),
            get_file_sas_url(
                "tools",
                "win64/onefuzz.ps1",
                account_id=get_func_storage(),
                read=True,
            ),
        ]
        to_execute_cmd = (
            "powershell -ExecutionPolicy Unrestricted -File win64/setup.ps1 "
            "-mode %s" % (mode.name)
        )
        # NOTE(review): this dict uses snake_case keys (type_handler_version)
        # while the linux branch uses camelCase (typeHandlerVersion) --
        # confirm both spellings are accepted by the deployment layer.
        extension = {
            "name": "CustomScriptExtension",
            "type": "CustomScriptExtension",
            "publisher": "Microsoft.Compute",
            "location": region,
            "type_handler_version": "1.9",
            "auto_upgrade_minor_version": True,
            "settings": {"commandToExecute": to_execute_cmd, "fileUris": urls},
            "protectedSettings": {},
        }
        return extension
    elif os == OS.linux:
        urls += [
            get_file_sas_url(
                "vm-scripts",
                "managed.sh",
                account_id=get_func_storage(),
                read=True,
            ),
            get_file_sas_url(
                "tools",
                "linux/azcopy",
                account_id=get_func_storage(),
                read=True,
            ),
            get_file_sas_url(
                "tools",
                "linux/setup.sh",
                account_id=get_func_storage(),
                read=True,
            ),
        ]
        to_execute_cmd = "sh setup.sh %s" % (mode.name)

        extension = {
            "name": "CustomScript",
            "publisher": "Microsoft.Azure.Extensions",
            "type": "CustomScript",
            "typeHandlerVersion": "2.1",
            "location": region,
            "autoUpgradeMinorVersion": True,
            "settings": {"commandToExecute": to_execute_cmd, "fileUris": urls},
            "protectedSettings": {},
        }
        return extension

    raise NotImplementedError("unsupported OS: %s" % os)
def fuzz_extensions(region: Region, os: OS, pool_name: str) -> List[Extension]:
    """Extensions for a fuzzing node: the fuzz-mode agent plus the generic set."""
    config_url = build_pool_config(pool_name)
    fuzz_agent = agent_config(region, os, AgentMode.fuzz, urls=[config_url])
    extensions = generic_extensions(region, os)
    extensions.append(fuzz_agent)
    return extensions
def repro_extensions(
    region: Region,
    repro_os: OS,
    repro_id: UUID,
    repro_config: ReproConfig,
    setup_container: Optional[str],
) -> List[Extension]:
    """Build extensions for a crash-repro VM: downloads the report, the
    crashing input, the repro scripts, and a generated task-setup script.

    Raises Exception if the referenced report cannot be parsed.
    """
    # TODO - what about contents of repro.ps1 / repro.sh?
    report = get_report(repro_config.container, repro_config.path)
    if report is None:
        raise Exception("invalid report: %s" % repro_config)

    commands = []
    if setup_container:
        commands += [
            "azcopy sync '%s' ./setup"
            % (get_container_sas_url(setup_container, read=True, list=True)),
        ]

    urls = [
        get_file_sas_url(repro_config.container, repro_config.path, read=True),
        get_file_sas_url(
            report.input_blob.container, report.input_blob.name, read=True
        ),
    ]

    repro_files = []
    if repro_os == OS.windows:
        repro_files = ["%s/repro.ps1" % repro_id]
        task_script = "\r\n".join(commands)
        script_name = "task-setup.ps1"
    else:
        repro_files = ["%s/repro.sh" % repro_id, "%s/repro-stdout.sh" % repro_id]
        commands += ["chmod -R +x setup"]
        task_script = "\n".join(commands)
        script_name = "task-setup.sh"

    save_blob(
        "task-configs",
        "%s/%s" % (repro_id, script_name),
        task_script,
        account_id=get_func_storage(),
    )

    # NOTE(review): the task-setup URL is appended once per repro file, so
    # on Linux (two repro files) it appears twice in `urls` -- likely
    # harmless but possibly unintended; verify.
    for repro_file in repro_files:
        urls += [
            get_file_sas_url(
                "repro-scripts",
                repro_file,
                account_id=get_func_storage(),
                read=True,
            ),
            get_file_sas_url(
                "task-configs",
                "%s/%s" % (repro_id, script_name),
                account_id=get_func_storage(),
                read=True,
            ),
        ]

    base_extension = agent_config(region, repro_os, AgentMode.repro, urls=urls)
    extensions = generic_extensions(region, repro_os)
    extensions += [base_extension]
    return extensions
def proxy_manager_extensions(region: Region) -> List[Extension]:
    """Extensions for a proxy-manager VM (always Linux)."""
    config_url = get_file_sas_url(
        "proxy-configs",
        "%s/config.json" % region,
        account_id=get_func_storage(),
        read=True,
    )
    manager_url = get_file_sas_url(
        "tools",
        "linux/onefuzz-proxy-manager",
        account_id=get_func_storage(),
        read=True,
    )

    proxy_agent = agent_config(
        region, OS.linux, AgentMode.proxy, urls=[config_url, manager_url]
    )
    extensions = generic_extensions(region, OS.linux)
    extensions.append(proxy_agent)
    return extensions

View File

@ -0,0 +1,45 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Tuple
from uuid import UUID
from onefuzztypes.models import Heartbeat as BASE
from onefuzztypes.models import HeartbeatEntry, HeartbeatSummary
from .orm import ORMMixin
class Heartbeat(BASE, ORMMixin):
    """ORM-backed storage of task heartbeats, keyed on (task_id, heartbeat_id)."""

    @classmethod
    def add(cls, entry: HeartbeatEntry) -> None:
        # Save one row per heartbeat type in the entry; the row key combines
        # the machine ID and the heartbeat type name.
        for value in entry.data:
            heartbeat_id = "-".join([str(entry.machine_id), value["type"].name])
            heartbeat = cls(
                task_id=entry.task_id,
                heartbeat_id=heartbeat_id,
                machine_id=entry.machine_id,
                heartbeat_type=value["type"],
            )
            heartbeat.save()

    @classmethod
    def get_heartbeats(cls, task_id: UUID) -> List[HeartbeatSummary]:
        # Collect every stored heartbeat row for the task into summaries.
        entries = cls.search(query={"task_id": [task_id]})
        result = []
        for entry in entries:
            result.append(
                HeartbeatSummary(
                    timestamp=entry.Timestamp,
                    machine_id=entry.machine_id,
                    type=entry.heartbeat_type,
                )
            )
        return result

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        # (partition key, row key) for the ORM layer
        return ("task_id", "heartbeat_id")

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
from onefuzztypes.enums import JobState, TaskState
from onefuzztypes.models import Job as BASE_JOB
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .tasks.main import Task
class Job(BASE_JOB, ORMMixin):
    """ORM wrapper for a fuzzing job, keyed solely on job_id."""

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # partition key only; no row key
        return ("job_id", None)

    @classmethod
    def search_states(cls, *, states: Optional[List[JobState]] = None) -> List["Job"]:
        # Find jobs, optionally restricted to the given states.
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    @classmethod
    def search_expired(cls) -> List["Job"]:
        # Jobs still "available" whose end_time has already passed.
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(
            query={"state": JobState.available()}, raw_unchecked_filter=time_filter
        )

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        # task_info is derived data; never persist it
        return {"task_info": ...}

    def event_include(self) -> Optional[MappingIntStrAny]:
        # fields broadcast on job events
        return {
            "job_id": ...,
            "state": ...,
            "error": ...,
        }

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        # NOTE(review): machine_id/scaleset_id look like node fields rather
        # than job fields -- possibly copy/pasted from another model; verify.
        return {
            "machine_id": ...,
            "state": ...,
            "scaleset_id": ...,
        }

    def init(self) -> None:
        logging.info("init job: %s", self.job_id)
        self.state = JobState.enabled
        self.save()

    def stopping(self) -> None:
        # Move to stopping; once every task has stopped, move to stopped.
        self.state = JobState.stopping
        logging.info("stopping job: %s", self.job_id)
        not_stopped = [
            task
            for task in Task.search(query={"job_id": [self.job_id]})
            if task.state != TaskState.stopped
        ]

        if not_stopped:
            for task in not_stopped:
                task.state = TaskState.stopping
                task.save()
        else:
            self.state = JobState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    def on_start(self) -> None:
        # try to keep this effectively idempotent
        if self.end_time is None:
            self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
            self.save()

View File

@ -0,0 +1,206 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Iterator, List, Optional
from azure.devops.connection import Connection
from azure.devops.credentials import BasicAuthentication
from azure.devops.exceptions import AzureDevOpsServiceError
from azure.devops.v6_0.work_item_tracking.models import (
CommentCreate,
JsonPatchOperation,
Wiql,
WorkItem,
)
from azure.devops.v6_0.work_item_tracking.work_item_tracking_client import (
WorkItemTrackingClient,
)
from memoization import cached
from onefuzztypes.models import ADOTemplate, Report
from .common import Render
@cached(ttl=60)
def get_ado_client(base_url: str, token: str) -> WorkItemTrackingClient:
    """Create (and cache for 60s) an ADO work-item client authenticated
    with the given personal access token."""
    creds = BasicAuthentication("PAT", token)
    connection = Connection(base_url=base_url, creds=creds)
    return connection.clients_v6_0.get_work_item_tracking_client()
@cached(ttl=60)
def get_valid_fields(
    client: WorkItemTrackingClient, project: Optional[str] = None
) -> List[str]:
    """Lowercased reference names of all work-item fields (cached for 60s)."""
    fields = client.get_fields(project=project, expand="ExtensionFields")
    return [field.reference_name.lower() for field in fields]
class ADO:
    """Files and updates Azure DevOps work items for crash reports, driven
    by an ADOTemplate whose fields are rendered against the report."""

    def __init__(
        self, container: str, filename: str, config: ADOTemplate, report: Report
    ):
        self.config = config
        self.renderer = Render(container, filename, report)
        self.client = get_ado_client(self.config.base_url, self.config.auth_token)
        # project name may itself be a template
        self.project = self.render(self.config.project)

    def render(self, template: str) -> str:
        # Render a template string against the report context.
        return self.renderer.render(template)

    def existing_work_items(self) -> Iterator[WorkItem]:
        """Yield work items matching the template's unique_fields."""
        filters = {}
        for key in self.config.unique_fields:
            if key == "System.TeamProject":
                value = self.render(self.config.project)
            else:
                value = self.render(self.config.ado_fields[key])
            filters[key.lower()] = value

        valid_fields = get_valid_fields(
            self.client, project=filters.get("system.teamproject")
        )

        post_query_filter = {}

        # WIQL (Work Item Query Language) is an SQL like query language that
        # doesn't support query params, safe quoting, or any other SQL-injection
        # protection mechanisms.
        #
        # As such, build the WIQL with a those fields we can pre-determine are
        # "safe" and otherwise use post-query filtering.
        parts = []
        for k, v in filters.items():
            # Only add pre-system approved fields to the query
            if k not in valid_fields:
                post_query_filter[k] = v
                continue

            # WIQL supports wrapping values in ' or " and escaping ' by doubling it
            #
            # For this System.Title: hi'there
            # use this query fragment: [System.Title] = 'hi''there'
            #
            # For this System.Title: hi"there
            # use this query fragment: [System.Title] = 'hi"there'
            #
            # For this System.Title: hi'"there
            # use this query fragment: [System.Title] = 'hi''"there'
            SINGLE = "'"
            parts.append("[%s] = '%s'" % (k, v.replace(SINGLE, SINGLE + SINGLE)))

        query = "select [System.Id] from WorkItems"
        if parts:
            query += " where " + " AND ".join(parts)

        wiql = Wiql(query=query)
        for entry in self.client.query_by_wiql(wiql).work_items:
            item = self.client.get_work_item(entry.id, expand="Fields")
            # apply the fields that couldn't be queried safely in WIQL
            lowered_fields = {x.lower(): str(y) for (x, y) in item.fields.items()}
            if post_query_filter and not all(
                [
                    k.lower() in lowered_fields and lowered_fields[k.lower()] == v
                    for (k, v) in post_query_filter.items()
                ]
            ):
                continue
            yield item

    def update_existing(self, item: WorkItem) -> None:
        """Apply the template's on_duplicate actions to an existing item."""
        if self.config.on_duplicate.comment:
            comment = self.render(self.config.on_duplicate.comment)
            self.client.add_comment(
                CommentCreate(comment),
                self.project,
                item.id,
            )

        document = []
        # bump numeric counter fields (missing fields start at 0)
        for field in self.config.on_duplicate.increment:
            value = int(item.fields[field]) if field in item.fields else 0
            value += 1
            document.append(
                JsonPatchOperation(
                    op="Replace", path="/fields/%s" % field, value=str(value)
                )
            )

        for field in self.config.on_duplicate.ado_fields:
            field_value = self.render(self.config.on_duplicate.ado_fields[field])
            document.append(
                JsonPatchOperation(
                    op="Replace", path="/fields/%s" % field, value=field_value
                )
            )

        # remap the item's state per the on_duplicate state transition table
        if item.fields["System.State"] in self.config.on_duplicate.set_state:
            document.append(
                JsonPatchOperation(
                    op="Replace",
                    path="/fields/System.State",
                    value=self.config.on_duplicate.set_state[
                        item.fields["System.State"]
                    ],
                )
            )

        if document:
            self.client.update_work_item(document, item.id, project=self.project)

    def create_new(self) -> None:
        """Create a fresh work item from the template, always tagged "Onefuzz"."""
        task_type = self.render(self.config.type)
        document = []
        if "System.Tags" not in self.config.ado_fields:
            document.append(
                JsonPatchOperation(
                    op="Add", path="/fields/System.Tags", value="Onefuzz"
                )
            )

        for field in self.config.ado_fields:
            value = self.render(self.config.ado_fields[field])
            if field == "System.Tags":
                value += ";Onefuzz"
            document.append(
                JsonPatchOperation(op="Add", path="/fields/%s" % field, value=value)
            )

        entry = self.client.create_work_item(
            document=document, project=self.project, type=task_type
        )

        if self.config.comment:
            comment = self.render(self.config.comment)
            self.client.add_comment(
                CommentCreate(comment),
                self.project,
                entry.id,
            )

    def process(self) -> None:
        # Update every matching work item, or create one if none matched.
        seen = False
        for work_item in self.existing_work_items():
            self.update_existing(work_item)
            seen = True

        if not seen:
            self.create_new()
def notify_ado(
    config: ADOTemplate, container: str, filename: str, report: Report
) -> None:
    """File or update an ADO work item for a report.

    Service and rendering errors are logged, not raised.
    """
    try:
        ADO(container, filename, config, report).process()
    except AzureDevOpsServiceError as err:
        logging.error("ADO report failed: %s", err)
    except ValueError as err:
        logging.error("ADO report value error: %s", err)

View File

@ -0,0 +1,61 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from jinja2.sandbox import SandboxedEnvironment
from onefuzztypes.models import Report
from ..azure.containers import auth_download_url
from ..jobs import Job
from ..tasks.config import get_setup_container
from ..tasks.main import Task
class Render:
    """Sandboxed Jinja2 rendering context for notification templates.

    Exposes the report, task/job configs, and pre-authenticated download
    URLs as template variables.  Raises ValueError if the report's task or
    job cannot be found.
    """

    def __init__(self, container: str, filename: str, report: Report):
        self.report = report
        self.container = container
        self.filename = filename
        task = Task.get(report.job_id, report.task_id)
        if not task:
            raise ValueError

        job = Job.get(report.job_id)
        if not job:
            raise ValueError

        self.task_config = task.config
        self.job_config = job.config
        # sandboxed environment keeps templates from reaching Python internals
        self.env = SandboxedEnvironment()

        self.target_url: Optional[str] = None
        setup_container = get_setup_container(task.config)
        if setup_container:
            # executable paths carry a "setup/" prefix; strip it for the blob name
            self.target_url = auth_download_url(
                setup_container, self.report.executable.replace("setup/", "", 1)
            )

        self.report_url = auth_download_url(container, filename)
        self.input_url: Optional[str] = None
        if self.report.input_blob:
            self.input_url = auth_download_url(
                self.report.input_blob.container, self.report.input_blob.name
            )

    def render(self, template: str) -> str:
        # Render `template` with the notification context variables.
        return self.env.from_string(template).render(
            {
                "report": self.report,
                "task": self.task_config,
                "job": self.job_config,
                "report_url": self.report_url,
                "input_url": self.input_url,
                "target_url": self.target_url,
                "report_container": self.container,
                "report_filename": self.filename,
                "repro_cmd": "onefuzz repro create_and_connect %s %s"
                % (self.container, self.filename),
            }
        )

View File

@ -0,0 +1,122 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Dict, List, Optional, Sequence, Tuple, Union
from uuid import UUID
from memoization import cached
from onefuzztypes import models
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import ADOTemplate, Error, NotificationTemplate, TeamsTemplate
from onefuzztypes.primitives import Container, Event
from ..azure.containers import get_container_metadata, get_file_sas_url
from ..azure.creds import get_fuzz_storage
from ..azure.queue import send_message
from ..dashboard import add_event
from ..orm import ORMMixin
from ..reports import get_report
from ..tasks.config import get_input_container_queues
from ..tasks.main import Task
from .ado import notify_ado
from .teams import notify_teams
class Notification(models.Notification, ORMMixin):
    """ORM wrapper for a notification config, keyed on (notification_id, container)."""

    @classmethod
    def get_by_id(cls, notification_id: UUID) -> Union[Error, "Notification"]:
        # Resolve a notification by ID; returns Error rather than raising.
        notifications = cls.search(query={"notification_id": [notification_id]})
        if not notifications:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["unable to find Notification"]
            )

        if len(notifications) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=["error identifying Notification"],
            )
        notification = notifications[0]
        return notification

    @classmethod
    def get_existing(
        cls, container: Container, config: NotificationTemplate
    ) -> Optional["Notification"]:
        # Return this container's notification with an identical config, if any.
        notifications = Notification.search(query={"container": [container]})
        for notification in notifications:
            if notification.config == config:
                return notification
        return None

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        # (partition key, row key)
        return ("notification_id", "container")
@cached(ttl=10)
def get_notifications(container: Container) -> List[Notification]:
    # Notifications configured for `container`; cached to limit table queries
    # during bursts of file events.
    return Notification.search(query={"container": [container]})
@cached(ttl=10)
def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
    """Pairs of (task, input-container queues) for every available task
    (cached for 10s)."""
    results = []
    for task in Task.search_states(states=TaskState.available()):
        queues = get_input_container_queues(task.config)
        if queues:
            results.append((task, queues))
    return results
@cached(ttl=60)
def container_metadata(container: Container) -> Optional[Dict[str, str]]:
    # cached wrapper to avoid re-fetching container metadata per file event
    return get_container_metadata(container)
def new_files(container: Container, filename: str) -> None:
    """Handle a newly-observed blob: fire configured notifications, queue
    the file to any task consuming this container, and emit a dashboard
    event describing it."""
    results: Dict[str, Event] = {"container": container, "file": filename}
    metadata = container_metadata(container)
    if metadata:
        results["metadata"] = metadata

    notifications = get_notifications(container)
    if notifications:
        # only parse the blob as a crash report when something will consume it
        report = get_report(container, filename)
        if report:
            results["executable"] = report.executable
            results["crash_type"] = report.crash_type
            results["crash_site"] = report.crash_site
            results["job_id"] = report.job_id
            results["task_id"] = report.task_id

        logging.info("notifications for %s %s %s", container, filename, notifications)
        done = []
        for notification in notifications:
            # ignore duplicate configurations
            if notification.config in done:
                continue
            done.append(notification.config)

            # Teams notifications work with or without a parsed report
            if isinstance(notification.config, TeamsTemplate):
                notify_teams(notification.config, container, filename, report)

            if not report:
                continue

            # ADO notifications require a parsed report
            if isinstance(notification.config, ADOTemplate):
                notify_ado(notification.config, container, filename, report)

    # forward the blob URL to any task whose input queue is this container
    for (task, containers) in get_queue_tasks():
        if container in containers:
            logging.info("queuing input %s %s %s", container, filename, task.task_id)
            url = get_file_sas_url(container, filename, read=True, delete=True)
            send_message(
                task.task_id, bytes(url, "utf-8"), account_id=get_fuzz_storage()
            )

    add_event("new_file", results)

View File

@ -0,0 +1,127 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Any, Dict, List, Optional
import requests
from onefuzztypes.models import Report, TeamsTemplate
from ..azure.containers import auth_download_url
from ..tasks.config import get_setup_container
from ..tasks.main import Task
def markdown_escape(data: str) -> str:
    """Escape Markdown control characters so *data* renders literally.

    Backslash-escapes each special character and doubles backticks.
    """
    escaped = []
    for ch in data:
        if ch in "\\*_{}[]()#+-.!":
            escaped.append("\\" + ch)
        elif ch == "`":
            escaped.append("``")
        else:
            escaped.append(ch)
    return "".join(escaped)
def code_block(data: str) -> str:
    """Wrap *data* in a fenced Markdown code block, doubling any backticks."""
    sanitized = data.replace("`", "``")
    return "\n```{}\n```\n".format(sanitized)
def send_teams_webhook(
    config: TeamsTemplate,
    title: str,
    facts: List[Dict[str, str]],
    text: Optional[str],
) -> None:
    """POST a MessageCard payload to the configured Teams webhook.

    Failures are logged, not raised: notifications are best-effort.
    """
    title = markdown_escape(title)
    sections: List[Dict[str, Any]] = [{"activityTitle": title, "facts": facts}]
    if text:
        sections.append({"text": text})
    payload: Dict[str, Any] = {
        "@type": "MessageCard",
        "@context": "https://schema.org/extensions",
        "summary": title,
        "sections": sections,
    }
    response = requests.post(config.url, json=payload)
    if not response.ok:
        logging.error("webhook failed %s %s", response.status_code, response.content)
def notify_teams(
    config: TeamsTemplate, container: str, filename: str, report: Optional[Report]
) -> None:
    """Send a Teams notification for a newly seen file.

    With a crash report, the card carries crash details, download links, a
    repro command, and the call stack; without a report it only links the file.
    """
    text = None
    facts: List[Dict[str, str]] = []
    if report:
        task = Task.get(report.job_id, report.task_id)
        if not task:
            # report references a task we cannot find; nothing useful to send
            logging.error(
                "report with invalid task %s:%s", report.job_id, report.task_id
            )
            return
        title = "new crash in %s: %s @ %s" % (
            report.executable,
            report.crash_type,
            report.crash_site,
        )
        links = [
            "[report](%s)" % auth_download_url(container, filename),
        ]
        setup_container = get_setup_container(task.config)
        if setup_container:
            links.append(
                "[executable](%s)"
                % auth_download_url(
                    setup_container,
                    # report paths carry a "setup/" prefix that is not part of
                    # the blob name within the setup container
                    report.executable.replace("setup/", "", 1),
                ),
            )
        if report.input_blob:
            links.append(
                "[input](%s)"
                % auth_download_url(
                    report.input_blob.container, report.input_blob.name
                ),
            )
        facts += [
            {"name": "Files", "value": " | ".join(links)},
            {
                "name": "Task",
                "value": markdown_escape(
                    "job_id: %s task_id: %s" % (report.job_id, report.task_id)
                ),
            },
            {
                "name": "Repro",
                "value": code_block(
                    "onefuzz repro create_and_connect %s %s" % (container, filename)
                ),
            },
        ]
        text = "## Call Stack\n" + "\n".join(code_block(x) for x in report.call_stack)
    else:
        title = "new file found"
        facts += [
            {
                "name": "file",
                "value": "[%s/%s](%s)"
                % (
                    markdown_escape(container),
                    markdown_escape(filename),
                    auth_download_url(container, filename),
                ),
            }
        ]
    send_teams_webhook(config, title, facts, text)

View File

@ -0,0 +1,435 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import json
from datetime import datetime
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from uuid import UUID
from azure.common import AzureConflictHttpError, AzureMissingResourceHttpError
from onefuzztypes.enums import (
ErrorCode,
JobState,
NodeState,
PoolState,
ScalesetState,
TaskState,
TelemetryEvent,
UpdateType,
VmState,
)
from onefuzztypes.models import Error
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field
from .azure.table import get_client
from .dashboard import add_event
from .telemetry import track_event_filtered
from .updates import queue_update
# TypeVar so classmethods on ORMMixin subclasses return the subclass type.
A = TypeVar("A", bound="ORMMixin")
# Every value-list type that may appear in a table query filter.
QUERY_VALUE_TYPES = Union[
    List[int],
    List[str],
    List[UUID],
    List[Region],
    List[Container],
    List[PoolName],
    List[VmState],
    List[ScalesetState],
    List[JobState],
    List[TaskState],
    List[PoolState],
    List[NodeState],
]
# Maps a model field name to the acceptable values for that field.
QueryFilter = Dict[str, QUERY_VALUE_TYPES]
# String-like types that build_filters() interpolates directly into a table
# query filter string.
SAFE_STRINGS = (UUID, Container, Region, PoolName)
# Types accepted as a PartitionKey/RowKey; normalized by resolve().
KEY = Union[int, str, UUID, Enum]
# Visibility delay (seconds) for queued "stopping"-style state updates.
QUEUE_DELAY_STOPPING_SECONDS = 30
# Visibility delay (seconds) for all other queued state updates.
QUEUE_DELAY_CREATE_SECONDS = 5
# Seconds per hour, for readability when specifying long durations.
HOURS = 60 * 60
def resolve(key: KEY) -> str:
    """Normalize a supported key value to the string form used for table keys.

    Checked in order: str (as-is), UUID (str form), Enum (member name),
    int (decimal string).  Anything else is unsupported.
    """
    converters = (
        (str, lambda k: k),
        (UUID, str),
        (Enum, lambda k: k.name),
        (int, str),
    )
    for kind, convert in converters:
        if isinstance(key, kind):
            return convert(key)
    raise NotImplementedError("unsupported type %s - %s" % (type(key), repr(key)))
def build_filters(
    cls: Type[A], query_args: Optional[QueryFilter]
) -> Tuple[Optional[str], QueryFilter]:
    """Split query args into a server-side filter string and post-filters.

    Values that can be expressed safely in the table query language become
    part of the returned filter string; anything else is returned in the
    second element and must be checked client-side via post_filter().

    Raises ValueError for fields not on the model and TypeError for
    mixed-type value lists.
    """
    if not query_args:
        return (None, {})
    partition_key_field, row_key_field = cls.key_fields()
    search_filter_parts = []
    post_filters: QueryFilter = {}
    for field, values in query_args.items():
        if field not in cls.__fields__:
            raise ValueError("unexpected field %s: %s" % (repr(field), cls))
        if not values:
            continue
        # key fields are stored under their table column names
        if field == partition_key_field:
            field_name = "PartitionKey"
        elif field == row_key_field:
            field_name = "RowKey"
        else:
            field_name = field
        parts: Optional[List[str]] = None
        if isinstance(values[0], int):
            parts = []
            for x in values:
                if not isinstance(x, int):
                    raise TypeError("unexpected type")
                parts.append("%s eq %d" % (field_name, x))
        elif isinstance(values[0], Enum):
            parts = []
            for x in values:
                if not isinstance(x, Enum):
                    raise TypeError("unexpected type")
                parts.append("%s eq '%s'" % (field_name, x.name))
        elif all(isinstance(x, SAFE_STRINGS) for x in values):
            # only values of known-safe string types are interpolated
            parts = ["%s eq '%s'" % (field_name, x) for x in values]
        else:
            # everything else is filtered client-side after the query
            post_filters[field_name] = values
        if parts:
            if len(parts) == 1:
                search_filter_parts.append(parts[0])
            else:
                # multiple values for one field OR together
                search_filter_parts.append("(" + " or ".join(parts) + ")")
    if search_filter_parts:
        return (" and ".join(search_filter_parts), post_filters)
    return (None, post_filters)
def post_filter(value: Any, filters: Optional[QueryFilter]) -> bool:
    """Apply client-side filters that could not be pushed into the query.

    Returns True when *value* carries every filtered field with an
    acceptable value (or when there are no filters at all).
    """
    if not filters:
        return True
    return all(
        name in value and value[name] in allowed
        for name, allowed in filters.items()
    )
# Shape of pydantic's include/exclude mappings (field name or index -> spec).
MappingIntStrAny = Mapping[Union[int, str], Any]
class ModelMixin(BaseModel):
    """Base model adding dict-export helpers used throughout the service."""

    def export_exclude(self) -> Optional[MappingIntStrAny]:
        # by default, expose every field when exporting
        return None

    def raw(
        self,
        *,
        by_alias: bool = False,
        exclude_none: bool = False,
        exclude: MappingIntStrAny = None,
        include: MappingIntStrAny = None,
    ) -> Dict[str, Any]:
        """Render the model as a plain dict of JSON-compatible values."""
        # cycling through json means all wrapped types get resolved, such as UUID
        serialized = self.json(
            by_alias=by_alias,
            exclude_none=exclude_none,
            exclude=exclude,
            include=include,
        )
        as_dict: Dict[str, Any] = json.loads(serialized)
        return as_dict
class ORMMixin(ModelMixin):
    """Azure Storage Table persistence for pydantic models.

    Subclasses implement key_fields() to map two model fields onto the
    table's (PartitionKey, RowKey) pair.  Saving also drives the queued
    state-update and dashboard-event side channels.
    """

    # Timestamp is managed by Azure Storage; loaded but never saved back.
    Timestamp: Optional[datetime] = Field(alias="Timestamp")
    # etag enables optimistic concurrency via replace_entity(if_match=...)
    etag: Optional[str]

    @classmethod
    def table_name(cls: Type[A]) -> str:
        # one table per model class, named after the class itself
        return cls.__name__

    @classmethod
    def get(
        cls: Type[A], PartitionKey: KEY, RowKey: Optional[KEY] = None
    ) -> Optional[A]:
        """Fetch one entity; the row key defaults to the partition key."""
        client = get_client(table=cls.table_name())
        partition_key = resolve(PartitionKey)
        row_key = resolve(RowKey) if RowKey else partition_key
        try:
            raw = client.get_entity(cls.table_name(), partition_key, row_key)
        except AzureMissingResourceHttpError:
            return None
        return cls.load(raw)

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        """Return (partition_key_field, row_key_field); row key may be None."""
        raise NotImplementedError("keys not defined")

    # FILTERS:
    # The following hooks control which fields flow to each output:
    # * save_exclude: Specify fields to *exclude* from saving to Storage Tables
    # * export_exclude: Specify the fields to *exclude* from sending to an external API
    # * telemetry_include: Specify the fields to *include* for telemetry
    #
    # For implementation details see:
    # https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
    def save_exclude(self) -> Optional[MappingIntStrAny]:
        return None

    def export_exclude(self) -> Optional[MappingIntStrAny]:
        # never expose persistence bookkeeping fields externally
        return {"etag": ..., "Timestamp": ...}

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        # empty include mapping: no fields sent to telemetry by default
        return {}

    def event_include(self) -> Optional[MappingIntStrAny]:
        # empty include mapping: no fields sent to dashboard events by default
        return {}

    def event(self) -> Any:
        """Dashboard-event payload for this object."""
        return self.raw(exclude_none=True, include=self.event_include())

    def telemetry(self) -> Any:
        """Telemetry payload for this object."""
        return self.raw(exclude_none=True, include=self.telemetry_include())

    def _queue_as_needed(self) -> None:
        # Upon ORM save with state, if the object has a state that needs work,
        # automatically queue it
        state = getattr(self, "state", None)
        if state is None:
            return
        needs_work = getattr(state, "needs_work", None)
        if needs_work is None:
            return
        if state not in needs_work():
            return
        # shutdown-style transitions are delayed longer before becoming visible
        if state.name in ["stopping", "stop", "shutdown"]:
            self.queue(visibility_timeout=QUEUE_DELAY_STOPPING_SECONDS)
        else:
            self.queue(visibility_timeout=QUEUE_DELAY_CREATE_SECONDS)

    def _event_as_needed(self) -> None:
        # Upon ORM save, if the object returns event data, we'll send it to the
        # dashboard event subsystem
        data = self.event()
        if not data:
            return
        add_event(self.table_name(), data)

    def get_keys(self) -> Tuple[KEY, KEY]:
        """Return this instance's (partition_key, row_key) values."""
        partition_key_field, row_key_field = self.key_fields()
        partition_key = getattr(self, partition_key_field)
        if row_key_field:
            row_key = getattr(self, row_key_field)
        else:
            # single-key tables reuse the partition key as the row key
            row_key = partition_key
        return (partition_key, row_key)

    def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
        """Persist this object, then trigger queue/telemetry/event side effects.

        With new=True, fails with an Error if the row already exists.
        With require_etag=True (and an etag present), the write is
        conditional on the stored etag.
        """
        # TODO: migrate to an inspect.signature() model
        raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
        # non-scalar values are stored as JSON strings
        for key in raw:
            if not isinstance(raw[key], (str, int)):
                raw[key] = json.dumps(raw[key])
        # for datetime fields that passed through filtering, use the real value,
        # rather than a serialized form
        for field in self.__fields__:
            if field not in raw:
                continue
            if self.__fields__[field].type_ == datetime:
                raw[field] = getattr(self, field)
        partition_key_field, row_key_field = self.key_fields()
        # PartitionKey and RowKey must be 'str'
        raw["PartitionKey"] = resolve(raw[partition_key_field])
        raw["RowKey"] = resolve(raw[row_key_field or partition_key_field])
        del raw[partition_key_field]
        if row_key_field in raw:
            del raw[row_key_field]
        client = get_client(table=self.table_name())
        # never save the timestamp
        if "Timestamp" in raw:
            del raw["Timestamp"]
        if new:
            try:
                self.etag = client.insert_entity(self.table_name(), raw)
            except AzureConflictHttpError:
                return Error(code=ErrorCode.UNABLE_TO_CREATE, errors=["row exists"])
        elif self.etag and require_etag:
            self.etag = client.replace_entity(
                self.table_name(), raw, if_match=self.etag
            )
        else:
            self.etag = client.insert_or_replace_entity(self.table_name(), raw)
        self._queue_as_needed()
        # models named after a TelemetryEvent member also report telemetry
        if self.table_name() in TelemetryEvent.__members__:
            telem = self.telemetry()
            if telem:
                track_event_filtered(TelemetryEvent[self.table_name()], telem)
        self._event_as_needed()
        return None

    def delete(self) -> None:
        """Delete this entity; missing rows are silently ignored."""
        # fire off an event so Signalr knows it's being deleted
        self._event_as_needed()
        partition_key, row_key = self.get_keys()
        # NOTE(review): other methods call get_client(table=...); confirm the
        # default client here targets the same table service.
        client = get_client()
        try:
            client.delete_entity(
                self.table_name(), resolve(partition_key), resolve(row_key)
            )
        except AzureMissingResourceHttpError:
            # It's OK if the component is already deleted
            pass

    @classmethod
    def load(cls: Type[A], data: Dict[str, Union[str, bytes, bytearray]]) -> A:
        """Convert a raw table entity back into a model instance."""
        partition_key_field, row_key_field = cls.key_fields()
        # a raw entity must not already carry the model-side key names
        if partition_key_field in data:
            raise Exception(
                "duplicate PartitionKey field %s for %s"
                % (partition_key_field, cls.table_name())
            )
        if row_key_field in data:
            raise Exception(
                "duplicate RowKey field %s for %s" % (row_key_field, cls.table_name())
            )
        data[partition_key_field] = data["PartitionKey"]
        if row_key_field is not None:
            data[row_key_field] = data["RowKey"]
        del data["PartitionKey"]
        del data["RowKey"]
        # JSON-decode fields whose annotation indicates a nested model or dict
        for key in inspect.signature(cls).parameters:
            if key not in data:
                continue
            annotation = inspect.signature(cls).parameters[key].annotation
            if inspect.isclass(annotation):
                if issubclass(annotation, BaseModel) or issubclass(annotation, dict):
                    data[key] = json.loads(data[key])
                    continue
            if getattr(annotation, "__origin__", None) == Union and any(
                inspect.isclass(x) and issubclass(x, BaseModel)
                for x in annotation.__args__
            ):
                data[key] = json.loads(data[key])
                continue
            # Required for Python >=3.7. In 3.6, a `Dict[_,_]` annotation is a class
            # according to `inspect.isclass`.
            if getattr(annotation, "__origin__", None) == dict:
                data[key] = json.loads(data[key])
                continue
        return cls.parse_obj(data)

    @classmethod
    def search(
        cls: Type[A],
        *,
        query: Optional[QueryFilter] = None,
        raw_unchecked_filter: Optional[str] = None,
        num_results: int = None,
    ) -> List[A]:
        """Query the table, applying any required client-side post filters.

        raw_unchecked_filter is AND-ed verbatim onto the generated filter;
        callers are responsible for its safety.
        """
        search_filter, post_filters = build_filters(cls, query)
        if raw_unchecked_filter is not None:
            if search_filter is None:
                search_filter = raw_unchecked_filter
            else:
                search_filter = "(%s) and (%s)" % (search_filter, raw_unchecked_filter)
        client = get_client(table=cls.table_name())
        entries = []
        for row in client.query_entities(
            cls.table_name(), filter=search_filter, num_results=num_results
        ):
            if not post_filter(row, post_filters):
                continue
            entry = cls.load(row)
            entries.append(entry)
        return entries

    def queue(
        self,
        *,
        method: Optional[Callable] = None,
        visibility_timeout: Optional[int] = None,
    ) -> None:
        """Queue a deferred state-update message for this object.

        Only valid for models with a `state` field whose class name maps to
        an UpdateType member.
        """
        if not hasattr(self, "state"):
            raise NotImplementedError("Queued an ORM mapping without State")
        update_type = UpdateType.__members__.get(type(self).__name__)
        if update_type is None:
            raise NotImplementedError("unsupported update type: %s" % self)
        method_name: Optional[str] = None
        if method is not None:
            if not hasattr(method, "__name__"):
                raise Exception("unable to queue method: %s" % method)
            method_name = method.__name__
        partition_key, row_key = self.get_keys()
        queue_update(
            update_type,
            resolve(partition_key),
            resolve(row_key),
            method=method_name,
            visibility_timeout=visibility_timeout,
        )

View File

@ -0,0 +1,807 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import Dict, List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import (
OS,
Architecture,
ErrorCode,
NodeState,
PoolState,
ScalesetState,
TaskState,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Node as BASE_NODE
from onefuzztypes.models import NodeCommand
from onefuzztypes.models import NodeTasks as BASE_NODE_TASK
from onefuzztypes.models import Pool as BASE_POOL
from onefuzztypes.models import Scaleset as BASE_SCALESET
from onefuzztypes.models import (
ScalesetNodeState,
ScalesetSummary,
WorkSet,
WorkSetSummary,
WorkUnitSummary,
)
from onefuzztypes.primitives import PoolName, Region
from pydantic import Field
from .azure.auth import build_auth
from .azure.creds import get_fuzz_storage
from .azure.image import get_os
from .azure.network import Network
from .azure.queue import create_queue, delete_queue, peek_queue, queue_object
from .azure.table import get_client
from .azure.vmss import (
UnableToUpdate,
create_vmss,
delete_vmss,
delete_vmss_nodes,
get_instance_id,
get_vmss,
get_vmss_size,
list_instance_ids,
reimage_vmss_nodes,
resize_vmss,
update_extensions,
)
from .extension import fuzz_extensions
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
# Future work:
#
# Enabling autoscaling for the scalesets based on the pool work queues.
# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
class Node(BASE_NODE, ORMMixin):
    """ORM model for a single fuzzing node, keyed by (pool_name, machine_id)."""

    @classmethod
    def search_states(
        cls,
        *,
        scaleset_id: Optional[UUID] = None,
        states: Optional[List[NodeState]] = None,
        pool_name: Optional[str] = None,
    ) -> List["Node"]:
        """Search for nodes, optionally narrowed by scaleset, state, or pool."""
        query: QueryFilter = {}
        narrowing = (
            ("scaleset_id", [scaleset_id] if scaleset_id else None),
            ("state", states),
            ("pool_name", [pool_name] if pool_name else None),
        )
        for field, values in narrowing:
            if values:
                query[field] = values
        return cls.search(query=query)

    @classmethod
    def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
        """Return the unique node with this machine_id, or None.

        Zero matches or an ambiguous duplicate both yield None.
        """
        matches = cls.search(query={"machine_id": [machine_id]})
        if len(matches) != 1:
            return None
        return matches[0]

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("pool_name", "machine_id")

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        # the task list is a transient view, never persisted
        return {"tasks": ...}

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "machine_id": ...,
            "state": ...,
            "scaleset_id": ...,
        }

    def event_include(self) -> Optional[MappingIntStrAny]:
        return {
            "pool_name": ...,
            "machine_id": ...,
            "state": ...,
            "scaleset_id": ...,
        }

    def scaleset_node_exists(self) -> bool:
        """Check whether this node's VM instance still exists in its scaleset."""
        if self.scaleset_id is None:
            return False
        scaleset = Scaleset.get_by_id(self.scaleset_id)
        if not isinstance(scaleset, Scaleset):
            # get_by_id returned an Error
            return False
        instance = get_instance_id(scaleset.scaleset_id, self.machine_id)
        return isinstance(instance, str)

    @classmethod
    def stop_task(cls, task_id: UUID) -> None:
        # For now, this just re-images the node. Eventually, this
        # should send a message to the node to let the agent shut down
        # gracefully
        for node in NodeTasks.get_nodes_by_task_id(task_id):
            if node.state in NodeState.ready_for_reset():
                # already on its way to a reset; nothing to do
                continue
            logging.info(
                "stopping task %s on machine_id:%s",
                task_id,
                node.machine_id,
            )
            node.state = NodeState.done
            node.save()
class NodeTasks(BASE_NODE_TASK, ORMMixin):
    """Association table linking nodes to the tasks running on them."""

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("machine_id", "task_id")

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        # restrict telemetry to the identifiers and the state
        return {
            "machine_id": ...,
            "task_id": ...,
            "state": ...,
        }

    @classmethod
    def get_nodes_by_task_id(cls, task_id: UUID) -> List["Node"]:
        """Resolve the Node records currently assigned to a task."""
        entries = cls.search(query={"task_id": [task_id]})
        candidates = (Node.get_by_machine_id(entry.machine_id) for entry in entries)
        # drop entries whose node record is gone
        return [node for node in candidates if node]

    @classmethod
    def get_by_machine_id(cls, machine_id: UUID) -> List["NodeTasks"]:
        """All task links for the given node."""
        return cls.search(query={"machine_id": [machine_id]})

    @classmethod
    def get_by_task_id(cls, task_id: UUID) -> List["NodeTasks"]:
        """All node links for the given task."""
        return cls.search(query={"task_id": [task_id]})
# this isn't anticipated to be needed by the client, hence it not
# being in onefuzztypes
class NodeMessage(ORMMixin):
    """Persisted command message destined for a specific agent."""

    agent_id: UUID
    # BUG FIX: the factory must run per-instance.  The previous code called
    # utcnow() once at class-definition time and used the bound `timestamp`
    # method of that frozen datetime as the factory, so every message shared
    # the import-time timestamp -- and got a float, not the declared str.
    message_id: str = Field(
        default_factory=lambda: str(datetime.datetime.utcnow().timestamp())
    )
    message: NodeCommand

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        # (PartitionKey, RowKey).  The row key must name an actual model
        # field; the previous value referenced a nonexistent "create_date"
        # field, which broke save/load/get_keys for this table.
        return ("agent_id", "message_id")

    @classmethod
    def get_messages(
        cls, agent_id: UUID, num_results: int = None
    ) -> List["NodeMessage"]:
        """Fetch queued messages for an agent, up to num_results."""
        entries: List["NodeMessage"] = cls.search(
            query={"agent_id": [agent_id]}, num_results=num_results
        )
        return entries

    @classmethod
    def delete_messages(cls, agent_id: UUID, message_ids: List[str]) -> None:
        """Delete the given messages for an agent in a single table batch."""
        client = get_client(table=cls.table_name())
        batch = client.batch(table_name=cls.table_name())
        for message_id in message_ids:
            batch.delete_entity(agent_id, message_id)
        client.commit_batch(cls.table_name(), batch)
class Pool(BASE_POOL, ORMMixin):
    """ORM model for a pool of fuzzing nodes, keyed by (name, pool_id)."""

    @classmethod
    def create(
        cls,
        *,
        name: PoolName,
        os: OS,
        arch: Architecture,
        managed: bool,
        client_id: Optional[UUID],
    ) -> "Pool":
        """Build (but do not save) a new pool instance."""
        return cls(
            name=name,
            os=os,
            arch=arch,
            managed=managed,
            client_id=client_id,
            config=None,
        )

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        # derived/transient views are never persisted
        return {
            "nodes": ...,
            "queue": ...,
            "work_queue": ...,
            "config": ...,
            "node_summary": ...,
        }

    def export_exclude(self) -> Optional[MappingIntStrAny]:
        # BUG FIX: the pydantic field is named "Timestamp" (see ORMMixin);
        # the previous lowercase "timestamp" key matched no field, so the
        # Azure-managed timestamp leaked into exports.
        return {
            "etag": ...,
            "Timestamp": ...,
        }

    def event_include(self) -> Optional[MappingIntStrAny]:
        return {
            "name": ...,
            "pool_id": ...,
            "os": ...,
            "state": ...,
            "managed": ...,
        }

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "pool_id": ...,
            "os": ...,
            "state": ...,
            "managed": ...,
        }

    def populate_scaleset_summary(self) -> None:
        """Refresh the transient scaleset summary from the scaleset table."""
        self.scaleset_summary = [
            ScalesetSummary(scaleset_id=x.scaleset_id, state=x.state)
            for x in Scaleset.search_by_pool(self.name)
        ]

    def populate_work_queue(self) -> None:
        """Refresh the transient work-queue summary from the pool queue."""
        self.work_queue = []
        # Only populate the work queue summaries once the pool has finished
        # initializing; before that the queue may not exist yet.
        if self.state == PoolState.init:
            return
        worksets = peek_queue(
            self.get_pool_queue(), account_id=get_fuzz_storage(), object_type=WorkSet
        )
        for workset in worksets:
            work_units = [
                WorkUnitSummary(
                    job_id=work_unit.job_id,
                    task_id=work_unit.task_id,
                    task_type=work_unit.task_type,
                )
                for work_unit in workset.work_units
            ]
            self.work_queue.append(WorkSetSummary(work_units=work_units))

    def get_pool_queue(self) -> str:
        """Name of the storage queue feeding work to this pool's nodes."""
        return "pool-%s" % self.pool_id.hex

    def init(self) -> None:
        """Create the pool's work queue and transition to running."""
        create_queue(self.get_pool_queue(), account_id=get_fuzz_storage())
        self.state = PoolState.running
        self.save()

    def schedule_workset(self, work_set: WorkSet) -> bool:
        """Queue a workset for this pool; returns False if the pool is dying."""
        # Don't schedule work for pools that can't and won't do work.
        if self.state in [PoolState.shutdown, PoolState.halt]:
            return False
        return queue_object(
            self.get_pool_queue(), work_set, account_id=get_fuzz_storage()
        )

    @classmethod
    def _get_unique(cls, query: QueryFilter) -> Union[Error, "Pool"]:
        # shared lookup helper: exactly one match is required
        pools = cls.search(query=query)
        if not pools:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])
        if len(pools) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
            )
        return pools[0]

    @classmethod
    def get_by_id(cls, pool_id: UUID) -> Union[Error, "Pool"]:
        """Look up a pool by its pool_id."""
        return cls._get_unique({"pool_id": [pool_id]})

    @classmethod
    def get_by_name(cls, name: PoolName) -> Union[Error, "Pool"]:
        """Look up a pool by its name."""
        return cls._get_unique({"name": [name]})

    @classmethod
    def search_states(cls, *, states: Optional[List[PoolState]] = None) -> List["Pool"]:
        """Search pools, optionally narrowed to the given states."""
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    def shutdown(self) -> None:
        """ shutdown allows nodes to finish current work then delete """
        scalesets = Scaleset.search_by_pool(self.name)
        nodes = Node.search(query={"pool_name": [self.name]})
        if not scalesets and not nodes:
            # nothing left in the pool; remove the record itself
            logging.info("pool stopped, deleting: %s", self.name)
            self.state = PoolState.halt
            self.delete()
            return
        for scaleset in scalesets:
            scaleset.state = ScalesetState.shutdown
            scaleset.save()
        for node in nodes:
            node.state = NodeState.shutdown
            node.save()
        self.save()

    def halt(self) -> None:
        """ halt the pool immediately """
        scalesets = Scaleset.search_by_pool(self.name)
        nodes = Node.search(query={"pool_name": [self.name]})
        if not scalesets and not nodes:
            # all resources gone; drop the queue and the record
            delete_queue(self.get_pool_queue(), account_id=get_fuzz_storage())
            logging.info("pool stopped, deleting: %s", self.name)
            self.state = PoolState.halt
            self.delete()
            return
        for scaleset in scalesets:
            scaleset.state = ScalesetState.halt
            scaleset.save()
        for node in nodes:
            logging.info(
                "deleting node from pool: %s (%s) - machine_id:%s",
                self.pool_id,
                self.name,
                node.machine_id,
            )
            node.delete()
        self.save()

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("name", "pool_id")
class Scaleset(BASE_SCALESET, ORMMixin):
    """ORM model for an Azure VM scaleset backing a pool, keyed by
    (pool_name, scaleset_id)."""

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        # node summaries are a transient view, never persisted
        return {"nodes": ...}

    def event_include(self) -> Optional[MappingIntStrAny]:
        return {
            "pool_name": ...,
            "scaleset_id": ...,
            "state": ...,
            "os": ...,
            "size": ...,
            "error": ...,
        }

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "scaleset_id": ...,
            "os": ...,
            "vm_sku": ...,
            "size": ...,
            "spot_instances": ...,
        }

    @classmethod
    def create(
        cls,
        *,
        pool_name: PoolName,
        vm_sku: str,
        image: str,
        region: Region,
        size: int,
        spot_instances: bool,
        tags: Dict[str, str],
        client_id: Optional[UUID] = None,
        client_object_id: Optional[UUID] = None,
    ) -> "Scaleset":
        """Build (but do not save) a new scaleset instance."""
        return cls(
            pool_name=pool_name,
            vm_sku=vm_sku,
            image=image,
            region=region,
            size=size,
            spot_instances=spot_instances,
            auth=build_auth(),
            client_id=client_id,
            client_object_id=client_object_id,
            tags=tags,
        )

    @classmethod
    def search_by_pool(cls, pool_name: PoolName) -> List["Scaleset"]:
        """All scalesets belonging to the given pool."""
        return cls.search(query={"pool_name": [pool_name]})

    @classmethod
    def get_by_id(cls, scaleset_id: UUID) -> Union[Error, "Scaleset"]:
        """Look up a scaleset by id; exactly one match is required."""
        scalesets = cls.search(query={"scaleset_id": [scaleset_id]})
        if not scalesets:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["unable to find scaleset"]
            )
        if len(scalesets) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["error identifying scaleset"]
            )
        scaleset = scalesets[0]
        return scaleset

    @classmethod
    def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]:
        """All scalesets whose managed identity has the given object id."""
        return cls.search(query={"client_object_id": [object_id]})

    def init(self) -> None:
        """Validate pool/image compatibility, then advance toward setup."""
        logging.info("scaleset init: %s", self.scaleset_id)
        # Handle the race condition between a pool being deleted and a
        # scaleset being added to the pool.
        pool = Pool.get_by_name(self.pool_name)
        if isinstance(pool, Error):
            self.error = pool
            self.state = ScalesetState.halt
            self.save()
            return
        if pool.state == PoolState.init:
            logging.info(
                "scaleset waiting for pool: %s - %s", self.pool_name, self.scaleset_id
            )
        elif pool.state == PoolState.running:
            # the image OS must match the pool OS before creation proceeds
            image_os = get_os(self.region, self.image)
            if isinstance(image_os, Error):
                self.error = image_os
                self.state = ScalesetState.creation_failed
            elif image_os != pool.os:
                self.error = Error(
                    code=ErrorCode.INVALID_REQUEST,
                    errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)],
                )
                self.state = ScalesetState.creation_failed
            else:
                self.state = ScalesetState.setup
        else:
            # NOTE(review): any other pool state (e.g. stopping) also advances
            # this scaleset to setup -- confirm this is intended.
            self.state = ScalesetState.setup
        self.save()

    def setup(self) -> None:
        """Create the underlying VMSS (and its network) if needed."""
        # TODO: How do we pass in SSH configs for Windows? Previously
        # This was done as part of the generated per-task setup script.
        logging.info("scaleset setup: %s", self.scaleset_id)
        network = Network(self.region)
        network_id = network.get_id()
        if not network_id:
            logging.info("creating network: %s", self.region)
            result = network.create()
            if isinstance(result, Error):
                self.error = result
                self.state = ScalesetState.creation_failed
                self.save()
                return
        if self.auth is None:
            self.error = Error(
                code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"]
            )
            self.state = ScalesetState.creation_failed
            self.save()
            return
        vmss = get_vmss(self.scaleset_id)
        if vmss is None:
            pool = Pool.get_by_name(self.pool_name)
            if isinstance(pool, Error):
                self.error = pool
                self.state = ScalesetState.halt
                self.save()
                return
            logging.info("creating scaleset: %s", self.scaleset_id)
            extensions = fuzz_extensions(self.region, pool.os, self.pool_name)
            result = create_vmss(
                self.region,
                self.scaleset_id,
                self.vm_sku,
                self.size,
                self.image,
                network_id,
                self.spot_instances,
                extensions,
                self.auth.password,
                self.auth.public_key,
                self.tags,
            )
            if isinstance(result, Error):
                self.error = result
                logging.error(
                    "stopping task because of failed vmss: %s %s",
                    self.scaleset_id,
                    result,
                )
                self.state = ScalesetState.creation_failed
            else:
                logging.info("creating scaleset: %s", self.scaleset_id)
        elif vmss.provisioning_state == "Creating":
            logging.info("Waiting on scaleset creation: %s", self.scaleset_id)
        else:
            logging.info("scaleset running: %s", self.scaleset_id)
            self.state = ScalesetState.running
            # record the managed identity so the scaleset's nodes can be
            # authorized (see is_authorized / get_by_object_id)
            self.client_object_id = vmss.identity.principal_id
        self.save()

    # result = 'did I modify the scaleset in azure'
    def cleanup_nodes(self) -> bool:
        """Garbage-collect nodes that are ready for reset (delete or reimage)."""
        if self.state == ScalesetState.halt:
            self.halt()
            return True
        nodes = Node.search_states(
            scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
        )
        if not nodes:
            logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
            return False
        to_delete = []
        to_reimage = []
        for node in nodes:
            # delete nodes that are not waiting on the scaleset GC
            if not node.scaleset_node_exists():
                node.delete()
            elif node.state in [NodeState.shutdown, NodeState.halt]:
                to_delete.append(node)
            else:
                to_reimage.append(node)
        # Perform operations until they fail due to scaleset getting locked
        try:
            if to_delete:
                self.delete_nodes(to_delete)
                for node in to_delete:
                    node.state = NodeState.halt
                    node.save()
            if to_reimage:
                self.reimage_nodes(to_reimage)
        except UnableToUpdate:
            logging.info("scaleset update already in progress: %s", self.scaleset_id)
        return True

    def resize(self) -> None:
        """Reconcile the scaleset toward new_size, growing or shrinking it."""
        logging.info(
            "scaleset resize: %s - current: %s new: %s",
            self.scaleset_id,
            self.size,
            self.new_size,
        )
        # no work needed to resize
        if self.new_size is None:
            self.state = ScalesetState.running
            self.save()
            return
        # just in case, always ensure size is within max capacity
        self.new_size = min(self.new_size, self.max_size())
        # Treat Azure knowledge of the size of the scaleset as "ground truth"
        size = get_vmss_size(self.scaleset_id)
        if size is None:
            logging.info("scaleset is unavailable. Re-queuing")
            self.save()
            return
        if size == self.new_size:
            # NOTE: this is the only place we reset to the 'running' state.
            # This ensures that our idea of scaleset size agrees with Azure
            node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
            if node_count == self.size:
                logging.info("resize finished: %s", self.scaleset_id)
                self.new_size = None
                self.state = ScalesetState.running
            else:
                logging.info(
                    "resize is finished, waiting for nodes to check in: "
                    "%s (%d of %d nodes checked in)",
                    self.scaleset_id,
                    node_count,
                    self.size,
                )
        # When adding capacity, call the resize API directly
        elif self.new_size > self.size:
            try:
                resize_vmss(self.scaleset_id, self.new_size)
            except UnableToUpdate:
                logging.info("scaleset is mid-operation already")
        # Shut down any nodes without work. Otherwise, rely on Scaleset.reimage_node
        # to pick up that the scaleset is too big upon task completion
        else:
            nodes = Node.search_states(
                scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free]
            )
            for node in nodes:
                if size > self.new_size:
                    node.state = NodeState.halt
                    node.save()
                    size -= 1
                else:
                    break
        self.save()

    def delete_nodes(self, nodes: List[Node]) -> None:
        """Delete the given VM instances from the scaleset."""
        if not nodes:
            logging.debug("no nodes to delete")
            return
        if self.state == ScalesetState.halt:
            # halting deletes the whole scaleset; individual deletes are moot
            logging.debug("scaleset delete will delete node: %s", self.scaleset_id)
            return
        machine_ids = [x.machine_id for x in nodes]
        logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
        delete_vmss_nodes(self.scaleset_id, machine_ids)
        self.size -= len(machine_ids)
        self.save()

    def reimage_nodes(self, nodes: List[Node]) -> None:
        """Reimage the given nodes, stopping any tasks still assigned to them."""
        from .tasks.main import Task

        if not nodes:
            logging.debug("no nodes to reimage")
            return
        for node in nodes:
            # mark any still-running tasks on the node as failed
            for entry in NodeTasks.get_by_machine_id(node.machine_id):
                task = Task.get_by_task_id(entry.task_id)
                if isinstance(task, Task):
                    if task.state in [TaskState.stopping, TaskState.stopped]:
                        continue
                    task.error = Error(
                        code=ErrorCode.TASK_FAILED,
                        errors=["node reimaged during task execution"],
                    )
                    task.state = TaskState.stopping
                    task.save()
                entry.delete()
        if self.state == ScalesetState.shutdown:
            # during a drain-style shutdown, reimage requests become deletes
            self.delete_nodes(nodes)
            return
        if self.state == ScalesetState.halt:
            logging.debug("scaleset delete will delete node: %s", self.scaleset_id)
            return
        machine_ids = [x.machine_id for x in nodes]
        result = reimage_vmss_nodes(self.scaleset_id, machine_ids)
        if isinstance(result, Error):
            raise Exception(
                "unable to reimage nodes: %s:%s - %s"
                % (self.scaleset_id, machine_ids, result)
            )

    def shutdown(self) -> None:
        """Drain-style shutdown: halt once the scaleset is empty."""
        logging.info("scaleset shutdown: %s", self.scaleset_id)
        size = get_vmss_size(self.scaleset_id)
        if size is None or size == 0:
            self.state = ScalesetState.halt
            self.halt()
            return
        self.save()

    def halt(self) -> None:
        """Immediately delete all node records and the underlying VMSS."""
        for node in Node.search_states(scaleset_id=self.scaleset_id):
            logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
            node.delete()
        vmss = get_vmss(self.scaleset_id)
        if vmss is None:
            logging.info("scaleset deleted: %s", self.scaleset_id)
            self.state = ScalesetState.halt
            self.delete()
        else:
            logging.info("scaleset deleting: %s", self.scaleset_id)
            delete_vmss(self.scaleset_id)
            self.save()

    def max_size(self) -> int:
        # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
        # virtual-machine-scale-sets-placement-groups#checklist-for-using-large-scale-sets
        if self.image.startswith("/"):
            # custom (gallery) images have a lower scaleset cap
            return 600
        else:
            return 1000

    @classmethod
    def search_states(
        cls, *, states: Optional[List[ScalesetState]] = None
    ) -> List["Scaleset"]:
        """Search scalesets, optionally narrowed to the given states."""
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    def update_nodes(self) -> None:
        """Rebuild the transient node-state list from Azure and the node table."""
        # Be in at-least 'setup' before checking for the list of VMs.
        # BUG FIX: this previously compared the state against the *method*
        # `self.init` (always False, since an enum never equals a bound
        # method); it must compare against the enum value.
        if self.state == ScalesetState.init:
            return
        nodes = Node.search_states(scaleset_id=self.scaleset_id)
        azure_nodes = list_instance_ids(self.scaleset_id)
        self.nodes = []
        for (machine_id, instance_id) in azure_nodes.items():
            node_state: Optional[ScalesetNodeState] = None
            for node in nodes:
                if node.machine_id == machine_id:
                    node_state = ScalesetNodeState(
                        machine_id=machine_id,
                        instance_id=instance_id,
                        state=node.state,
                    )
                    break
            if not node_state:
                # VM exists in Azure but has not checked in as a node yet
                node_state = ScalesetNodeState(
                    machine_id=machine_id,
                    instance_id=instance_id,
                )
            self.nodes.append(node_state)

    def update_configs(self) -> None:
        """Push current fuzzing extensions to the VMSS (running scalesets only)."""
        if self.state != ScalesetState.running:
            logging.debug(
                "scaleset not running, not updating configs: %s", self.scaleset_id
            )
            return
        pool = Pool.get_by_name(self.pool_name)
        if isinstance(pool, Error):
            self.error = pool
            return self.halt()
        logging.debug("updating scaleset configs: %s", self.scaleset_id)
        extensions = fuzz_extensions(self.region, pool.os, self.pool_name)
        try:
            update_extensions(self.scaleset_id, extensions)
        except UnableToUpdate:
            logging.debug(
                "unable to update configs, update already in progress: %s",
                self.scaleset_id,
            )

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("pool_name", "scaleset_id")

View File

@ -0,0 +1,236 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
import os
from typing import List, Optional, Tuple
from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import VmState
from onefuzztypes.models import (
Authentication,
Error,
Forward,
ProxyConfig,
ProxyHeartbeat,
)
from onefuzztypes.primitives import Region
from pydantic import Field
from .__version__ import __version__
from .azure.auth import build_auth
from .azure.containers import get_file_sas_url, save_blob
from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.vm import VM
from .extension import proxy_manager_extensions
from .orm import HOURS, MappingIntStrAny, ORMMixin, QueryFilter
from .proxy_forward import ProxyForward
# VM size and marketplace image used for every per-region proxy VM
PROXY_SKU = "Standard_B2s"
PROXY_IMAGE = "Canonical:UbuntuServer:18.04-LTS:latest"
# This isn't intended to ever be shared to the client, hence not being in
# onefuzztypes
class Proxy(ORMMixin):
    """Lifecycle management for the per-region SSH proxy VM.

    State machine (VmState): init -> extensions_launch -> running, with
    stopping -> stopped for teardown.  One proxy row exists per region.
    """

    region: Region
    state: VmState = Field(default=VmState.init)
    auth: Authentication = Field(default_factory=build_auth)
    ip: Optional[str]
    error: Optional[str]
    version: str = Field(default=__version__)
    heartbeat: Optional[ProxyHeartbeat]

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # partition key only: a single proxy row per region
        return ("region", None)

    def event_include(self) -> Optional[MappingIntStrAny]:
        # fields exposed when emitting events for this entity
        return {
            "region": ...,
            "state": ...,
            "ip": ...,
            "error": ...,
        }

    def get_vm(self) -> VM:
        """Build the VM wrapper for this region's proxy instance."""
        return VM(
            name="proxy-%s" % self.region,
            region=self.region,
            sku=PROXY_SKU,
            image=PROXY_IMAGE,
            auth=self.auth,
        )

    def init(self) -> None:
        """Create the proxy VM; once provisioned, stage config and advance."""
        vm = self.get_vm()
        vm_data = vm.get()
        if vm_data:
            if vm_data.provisioning_state == "Failed":
                # BUGFIX: set_failed expects the Azure VirtualMachine model
                # (it reads instance_view); the original passed the VM wrapper
                self.set_failed(vm_data)
            else:
                self.save_proxy_config()
                self.state = VmState.extensions_launch
        else:
            result = vm.create()
            if isinstance(result, Error):
                self.error = repr(result)
                self.state = VmState.stopping
        self.save()

    def set_failed(self, vm_data: VirtualMachine) -> None:
        """Log provisioning-failure details and mark allocation as failed."""
        logging.error("vm failed to provision: %s", vm_data.name)
        for status in vm_data.instance_view.statuses:
            if status.level.name.lower() == "error":
                logging.error(
                    "vm status: %s %s %s %s",
                    vm_data.name,
                    status.code,
                    status.display_status,
                    status.message,
                )
        self.state = VmState.vm_allocation_failed

    def extensions_launch(self) -> None:
        """Record the public IP and install the proxy-manager extensions."""
        vm = self.get_vm()
        vm_data = vm.get()
        if not vm_data:
            logging.error("Azure VM does not exist: %s", vm.name)
            self.state = VmState.stopping
            self.save()
            return

        if vm_data.provisioning_state == "Failed":
            self.set_failed(vm_data)
            self.save()
            return

        ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
        if ip is None:
            # IP not yet allocated; try again on the next pass
            self.save()
            return
        self.ip = ip

        extensions = proxy_manager_extensions(self.region)
        result = vm.add_extensions(extensions)
        if isinstance(result, Error):
            logging.error("vm extension failed: %s", repr(result))
            self.error = repr(result)
            self.state = VmState.stopping
        elif result:
            # extensions fully deployed; proxy is serving
            self.state = VmState.running

        self.save()

    def stopping(self) -> None:
        """Delete the VM; once confirmed gone, remove the proxy entry."""
        vm = self.get_vm()
        if not vm.is_deleted():
            logging.info("stopping proxy: %s", self.region)
            vm.delete()
            self.save()
        else:
            self.stopped()

    def stopped(self) -> None:
        """Final cleanup once the VM is confirmed deleted."""
        logging.info("removing proxy: %s", self.region)
        self.delete()

    def is_used(self) -> bool:
        """True when any port forwards still reference this proxy."""
        if len(self.get_forwards()) == 0:
            logging.info("proxy has no forwards: %s", self.region)
            return False
        return True

    def is_alive(self) -> bool:
        """True when the proxy heartbeat (or row) is under 10 minutes old."""
        # Unfortunately, with and without TZ information is required for compare
        # or exceptions are generated
        ten_minutes_ago_no_tz = datetime.datetime.utcnow() - datetime.timedelta(
            minutes=10
        )
        ten_minutes_ago = ten_minutes_ago_no_tz.astimezone(datetime.timezone.utc)
        if (
            self.heartbeat is not None
            and self.heartbeat.timestamp < ten_minutes_ago_no_tz
        ):
            # FIX: corrected log grammar ("more than an 10 minutes")
            logging.error(
                "proxy last heartbeat is more than 10 minutes old: %s", self.region
            )
            return False

        elif not self.heartbeat and self.Timestamp and self.Timestamp < ten_minutes_ago:
            logging.error(
                "proxy has no heartbeat in the last 10 minutes: %s", self.region
            )
            return False

        return True

    def get_forwards(self) -> List[Forward]:
        """Active port forwards for this region; expired entries are deleted."""
        forwards: List[Forward] = []
        for entry in ProxyForward.search_forward(region=self.region):
            if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
                # forward expired; garbage collect the row
                entry.delete()
            else:
                forwards.append(
                    Forward(
                        src_port=entry.port,
                        dst_ip=entry.dst_ip,
                        dst_port=entry.dst_port,
                    )
                )
        return forwards

    def save_proxy_config(self) -> None:
        """Write the proxy's config blob (SAS URLs + forwards) for the VM."""
        forwards = self.get_forwards()
        proxy_config = ProxyConfig(
            url=get_file_sas_url(
                "proxy-configs",
                "%s/config.json" % self.region,
                account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
                read=True,
            ),
            notification=get_queue_sas(
                "proxy",
                account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
                add=True,
            ),
            forwards=forwards,
            region=self.region,
        )

        save_blob(
            "proxy-configs",
            "%s/config.json" % self.region,
            proxy_config.json(),
            account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
        )

    def queue_stop(self, count: int) -> None:
        """Schedule a stop check `count` hours from now."""
        self.queue(method=self.stopping, visibility_timeout=count * HOURS)

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
        """Search proxies, optionally restricted to the given states."""
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    @classmethod
    def get_or_create(cls, region: Region) -> Optional["Proxy"]:
        """Return the region's proxy, creating one if needed.

        Returns None while an out-of-date proxy is being torn down; callers
        should retry later to obtain a freshly created instance.
        """
        proxy = Proxy.get(region)
        if proxy is not None:
            if proxy.version != __version__:
                # If the proxy is out-of-date, delete and re-create it
                proxy.state = VmState.stopping
                proxy.save()
                return None
            return proxy

        proxy = Proxy(region=region)
        proxy.save()
        return proxy

View File

@ -0,0 +1,134 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Forward
from onefuzztypes.primitives import Region
from pydantic import Field
from .azure.ip import get_scaleset_instance_ip
from .orm import ORMMixin, QueryFilter
PORT_RANGES = range(6000, 7000)
# This isn't intended to ever be shared to the client, hence not being in
# onefuzztypes
class ProxyForward(ORMMixin):
    """ORM row describing one proxy port forward: a proxy-side source port
    in PORT_RANGES mapped to a (scaleset, machine) private ip and port.
    """

    region: Region
    port: int
    scaleset_id: UUID
    machine_id: UUID
    dst_ip: str
    dst_port: int
    # NOTE(review): default is naive utcnow(); Proxy.get_forwards compares
    # endtime against an aware datetime.now(tz=utc) -- presumably the ORM
    # round-trip normalizes tz.  TODO confirm.
    endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        # (partition key, row key): one row per (region, proxy source port)
        return ("region", "port")

    @classmethod
    def update_or_create(
        cls,
        region: Region,
        scaleset_id: UUID,
        machine_id: UUID,
        dst_port: int,
        duration: int,
    ) -> Union["ProxyForward", Error]:
        """Extend an existing forward's lifetime, or allocate a new source
        port for (scaleset_id, machine_id, dst_port) in the given region.

        `duration` is in hours.  Returns an Error when the node has no
        private IP or every port in PORT_RANGES is taken.
        """
        private_ip = get_scaleset_instance_ip(scaleset_id, machine_id)
        if not private_ip:
            return Error(
                code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["no private ip for node"]
            )

        # reuse an existing forward for this exact mapping, if one exists
        entries = cls.search_forward(
            scaleset_id=scaleset_id,
            machine_id=machine_id,
            dst_port=dst_port,
            region=region,
        )
        if entries:
            entry = entries[0]
            entry.endtime = datetime.datetime.utcnow() + datetime.timedelta(
                hours=duration
            )
            entry.save()
            return entry

        # NOTE(review): `entries` is always empty here (the non-empty case
        # returned above), so `existing` never filters anything; the
        # save(new=True) collision check below is what actually detects
        # ports already in use.
        existing = [int(x.port) for x in entries]
        for port in PORT_RANGES:
            if port in existing:
                continue

            entry = cls(
                region=region,
                port=port,
                scaleset_id=scaleset_id,
                machine_id=machine_id,
                dst_ip=private_ip,
                dst_port=dst_port,
                endtime=datetime.datetime.utcnow() + datetime.timedelta(hours=duration),
            )
            # save(new=True) fails if the (region, port) row already exists;
            # on collision, try the next candidate port
            result = entry.save(new=True)
            if isinstance(result, Error):
                logging.info("port is already used: %s", entry)
                continue
            return entry

        return Error(
            code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["all forward ports used"]
        )

    @classmethod
    def remove_forward(
        cls,
        scaleset_id: UUID,
        *,
        machine_id: Optional[UUID] = None,
        dst_port: Optional[int] = None,
    ) -> List[Region]:
        """Delete matching forwards; return the regions that were affected."""
        entries = cls.search_forward(
            scaleset_id=scaleset_id, machine_id=machine_id, dst_port=dst_port
        )
        regions = set()
        for entry in entries:
            regions.add(entry.region)
            entry.delete()
        return list(regions)

    @classmethod
    def search_forward(
        cls,
        *,
        scaleset_id: Optional[UUID] = None,
        region: Optional[Region] = None,
        machine_id: Optional[UUID] = None,
        dst_port: Optional[int] = None,
    ) -> List["ProxyForward"]:
        """Search forwards; each provided keyword narrows the query."""
        query: QueryFilter = {}
        if region is not None:
            query["region"] = [region]

        if scaleset_id is not None:
            query["scaleset_id"] = [scaleset_id]

        if machine_id is not None:
            query["machine_id"] = [machine_id]

        if dst_port is not None:
            query["dst_port"] = [dst_port]

        return cls.search(query=query)

    def to_forward(self) -> Forward:
        """Convert this row to the client-facing Forward model."""
        return Forward(src_port=self.port, dst_ip=self.dst_ip, dst_port=self.dst_port)

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import logging
from typing import Optional, Union

from onefuzztypes.models import Report
from pydantic import ValidationError

from .azure.containers import get_blob
def parse_report(
    content: Union[str, bytes], metadata: Optional[str] = None
) -> "Optional[Report]":
    """Parse a crash-report blob into a Report model.

    FIX: the annotation previously claimed `content: str`, but the body
    explicitly handles bytes -- the signature now reflects that.

    `metadata` is used only to label log messages.  Returns None (after
    logging) on any decode, JSON, or validation failure.
    """
    if isinstance(content, bytes):
        try:
            content = content.decode()
        except UnicodeDecodeError as err:
            logging.error(
                "unable to parse report (%s): unicode decode of report failed - %s",
                metadata,
                err,
            )
            return None

    try:
        data = json.loads(content)
    except json.decoder.JSONDecodeError as err:
        logging.error(
            "unable to parse report (%s): json decoding failed - %s", metadata, err
        )
        return None

    try:
        entry = Report.parse_obj(data)
    except ValidationError as err:
        logging.error("unable to parse report (%s): %s", metadata, err)
        return None

    return entry
def get_report(container: str, filename: str) -> Optional[Report]:
    """Fetch and parse a report blob; None when missing or invalid."""
    metadata = "/".join([container, filename])

    # only .json blobs are candidate reports
    if not filename.endswith(".json"):
        logging.error("get_report invalid extension: %s", metadata)
        return None

    blob = get_blob(container, filename)
    if blob is None:
        logging.error("get_report invalid blob: %s", metadata)
        return None

    return parse_report(blob, metadata=metadata)

View File

@ -0,0 +1,237 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import List, Optional, Tuple, Union
from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import OS, ContainerType, ErrorCode, VmState
from onefuzztypes.models import Error
from onefuzztypes.models import Repro as BASE_REPRO
from onefuzztypes.models import ReproConfig, TaskVm
from .azure.auth import build_auth
from .azure.containers import save_blob
from .azure.creds import get_base_region, get_func_storage
from .azure.ip import get_public_ip
from .azure.vm import VM
from .extension import repro_extensions
from .orm import HOURS, ORMMixin, QueryFilter
from .reports import get_report
from .tasks.main import Task
# fallback marketplace images per OS, used when the task's pool has no
# scalesets from which to derive a repro VM config
DEFAULT_OS = {
    OS.linux: "Canonical:UbuntuServer:18.04-LTS:latest",
    OS.windows: "MicrosoftWindowsDesktop:Windows-10:rs5-pro:latest",
}
# fallback VM size for repro VMs
DEFAULT_SKU = "Standard_DS1_v2"
class Repro(BASE_REPRO, ORMMixin):
    """Lifecycle management for a single crash-reproduction VM.

    State machine (VmState): init -> extensions_launch -> running, with
    stopping -> stopped for teardown.  The VM runs a debugger loop against
    the crashing input described by the repro config.
    """

    def set_error(self, error: Error) -> None:
        """Record a fatal error and begin tearing the VM down."""
        logging.error(
            "repro failed: vm_id: %s task_id: %s: error: %s",
            self.vm_id,
            self.task_id,
            error,
        )
        self.error = error
        self.state = VmState.stopping
        self.save()

    def get_vm(self) -> VM:
        """Build the VM wrapper, deriving config from the task (or defaults)."""
        task = Task.get_by_task_id(self.task_id)
        if isinstance(task, Error):
            # BUGFIX: %-format the message (the original passed the id as a
            # second Exception argument, leaving "%s" unformatted)
            raise Exception("previously existing task missing: %s" % self.task_id)

        vm_config = task.get_repro_vm_config()
        if vm_config is None:
            # if using a pool without any scalesets defined yet, use reasonable defaults
            if task.os not in DEFAULT_OS:
                raise NotImplementedError("unsupported OS for repro %s" % task.os)

            vm_config = TaskVm(
                region=get_base_region(), sku=DEFAULT_SKU, image=DEFAULT_OS[task.os]
            )

        if self.auth is None:
            raise Exception("missing auth")

        return VM(
            name=self.vm_id,
            region=vm_config.region,
            sku=vm_config.sku,
            image=vm_config.image,
            auth=self.auth,
        )

    def init(self) -> None:
        """Create the VM if needed; once provisioned, stage the repro script."""
        vm = self.get_vm()
        vm_data = vm.get()
        if vm_data:
            if vm_data.provisioning_state == "Failed":
                # BUGFIX: set_failed expects the Azure VirtualMachine model
                # (it reads instance_view); the original passed the VM wrapper
                self.set_failed(vm_data)
            else:
                script_result = self.build_repro_script()
                if isinstance(script_result, Error):
                    return self.set_error(script_result)

                self.state = VmState.extensions_launch
        else:
            result = vm.create()
            if isinstance(result, Error):
                return self.set_error(result)
        self.save()

    def set_failed(self, vm_data: VirtualMachine) -> None:
        """Collect provisioning error statuses and fail the repro."""
        errors = []
        for status in vm_data.instance_view.statuses:
            if status.level.name.lower() == "error":
                errors.append(
                    "%s %s %s" % (status.code, status.display_status, status.message)
                )
        return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors))

    def get_setup_container(self) -> Optional[str]:
        """Name of the task's setup container, if the task still exists."""
        task = Task.get_by_task_id(self.task_id)
        if isinstance(task, Task):
            for container in task.config.containers:
                if container.type == ContainerType.setup:
                    return container.name
        return None

    def extensions_launch(self) -> None:
        """Record the public IP and install the repro extensions."""
        vm = self.get_vm()
        vm_data = vm.get()
        if not vm_data:
            return self.set_error(
                Error(
                    code=ErrorCode.VM_CREATE_FAILED,
                    errors=["failed before launching extensions"],
                )
            )

        if vm_data.provisioning_state == "Failed":
            return self.set_failed(vm_data)

        if not self.ip:
            self.ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)

        extensions = repro_extensions(
            vm.region, self.os, self.vm_id, self.config, self.get_setup_container()
        )
        result = vm.add_extensions(extensions)
        if isinstance(result, Error):
            return self.set_error(result)
        elif result:
            # extensions fully deployed; the repro VM is usable
            self.state = VmState.running

        self.save()

    def stopping(self) -> None:
        """Delete the VM; once confirmed gone, remove the repro entry."""
        vm = self.get_vm()
        if not vm.is_deleted():
            logging.info("vm stopping: %s", self.vm_id)
            vm.delete()
            self.save()
        else:
            self.stopped()

    def stopped(self) -> None:
        """Final cleanup once the VM is confirmed deleted."""
        logging.info("vm stopped: %s", self.vm_id)
        self.delete()

    def build_repro_script(self) -> Optional[Error]:
        """Upload the per-OS debugger script for this repro VM.

        Returns an Error on failure, None on success.
        """
        if self.auth is None:
            return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing auth"])

        task = Task.get_by_task_id(self.task_id)
        if isinstance(task, Error):
            return task

        report = get_report(self.config.container, self.config.path)
        if report is None:
            return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing report"])

        files = {}

        if task.os == OS.windows:
            ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
            cmds = [
                'Set-Content -Path %s -Value "%s"' % (ssh_path, self.auth.public_key),
                ". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
                "Set-SetSSHACL",
                'while (1) { cdb -server tcp:port=1337 -c "g" setup\\%s %s }'
                % (
                    task.config.task.target_exe,
                    report.input_blob.name,
                ),
            ]
            cmd = "\r\n".join(cmds)
            files["repro.ps1"] = cmd
        elif task.os == OS.linux:
            gdb_fmt = (
                "ASAN_OPTIONS='abort_on_error=1' gdbserver "
                "%s /onefuzz/setup/%s /onefuzz/downloaded/%s"
            )
            cmd = "while :; do %s; done" % (
                gdb_fmt
                % (
                    "localhost:1337",
                    task.config.task.target_exe,
                    report.input_blob.name,
                )
            )
            files["repro.sh"] = cmd

            cmd = "#!/bin/bash\n%s" % (
                gdb_fmt % ("-", task.config.task.target_exe, report.input_blob.name)
            )
            files["repro-stdout.sh"] = cmd
        else:
            raise NotImplementedError("invalid task os: %s" % task.os)

        for filename in files:
            save_blob(
                "repro-scripts",
                "%s/%s" % (self.vm_id, filename),
                files[filename],
                account_id=get_func_storage(),
            )

        logging.info("saved repro script")
        return None

    def queue_stop(self, count: int) -> None:
        """Schedule a stop check `count` hours from now."""
        self.queue(method=self.stopping, visibility_timeout=count * HOURS)

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
        """Search repro VMs, optionally restricted to the given states."""
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    @classmethod
    def create(cls, config: ReproConfig) -> Union[Error, "Repro"]:
        """Create a repro VM for the report referenced by `config`."""
        report = get_report(config.container, config.path)
        if not report:
            return Error(
                code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find report"]
            )

        task = Task.get_by_task_id(report.task_id)
        if isinstance(task, Error):
            return task

        vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
        vm.save()
        vm.queue_stop(config.duration)
        return vm

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # partition key only: one row per repro vm_id
        return ("vm_id", None)

View File

@ -0,0 +1,151 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import logging
import os
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
from uuid import UUID
from azure.functions import HttpRequest, HttpResponse
from azure.graphrbac.models import GraphErrorException
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.responses import BaseResponse
from pydantic import ValidationError
from .azure.creds import is_member_of
from .orm import ModelMixin
# We don't actually use these types at runtime at this time. Rather,
# these are used in a bound TypeVar. MyPy suggests to only import these
# types during type checking.
if TYPE_CHECKING:
from onefuzztypes.requests import BaseRequest # noqa: F401
from pydantic import BaseModel # noqa: F401
def check_access(req: HttpRequest) -> Optional[Error]:
    """Deny access unless the caller is in the configured AAD group.

    Returns None when access is allowed (including when no group is
    configured), otherwise an UNAUTHORIZED Error.
    """
    group_id = os.environ.get("ONEFUZZ_AAD_GROUP_ID")
    if group_id is None:
        return None

    member_id = req.headers["x-ms-client-principal-id"]

    try:
        allowed = is_member_of(group_id, member_id)
    except GraphErrorException:
        return Error(
            code=ErrorCode.UNAUTHORIZED, errors=["unable to interact with graph"]
        )

    if allowed:
        return None

    logging.error("unauthorized access: %s is not in %s", member_id, group_id)
    return Error(
        code=ErrorCode.UNAUTHORIZED,
        errors=["not approved to use this instance of onefuzz"],
    )
def ok(
    data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
) -> HttpResponse:
    """Serialize a success payload (single model or a sequence) as JSON."""
    if isinstance(data, BaseResponse):
        return HttpResponse(data.json(exclude_none=True), mimetype="application/json")

    if isinstance(data, list) and data and isinstance(data[0], BaseResponse):
        items = [json.loads(entry.json(exclude_none=True)) for entry in data]
        return HttpResponse(json.dumps(items), mimetype="application/json")

    if isinstance(data, ModelMixin):
        return HttpResponse(
            data.json(exclude_none=True, exclude=data.export_exclude()),
            mimetype="application/json",
        )

    # mixed sequence: export ORM models via raw(), pass everything else through
    items = [
        entry.raw(exclude_none=True, exclude=entry.export_exclude())
        if isinstance(entry, ModelMixin)
        else entry
        for entry in data
    ]
    return HttpResponse(json.dumps(items), mimetype="application/json")
def not_ok(
    error: "Error", *, status_code: int = 400, context: Union[str, UUID]
) -> "HttpResponse":
    """Return a JSON error response with the given 4xx/5xx status code.

    Raises (programming error) when status_code is outside [400, 599].
    """
    if 400 <= status_code <= 599:
        logging.error("request error - %s: %s" % (str(context), error.json()))
        return HttpResponse(
            error.json(), status_code=status_code, mimetype="application/json"
        )
    # BUGFIX: message typo -- "is not int the expected range"
    raise Exception(
        "status code %s is not in the expected range [400; 599]" % status_code
    )
def redirect(location: str) -> HttpResponse:
    """Return an HTTP 302 redirect to `location`."""
    headers = {"Location": location}
    return HttpResponse(status_code=302, headers=headers)
def convert_error(err: ValidationError) -> Error:
    """Flatten a pydantic ValidationError into an INVALID_REQUEST Error."""
    messages = []
    for item in err.errors():
        loc = item["loc"]
        if isinstance(loc, tuple):
            name = ".".join(str(part) for part in loc)
        else:
            name = str(loc)
        messages.append("%s: %s" % (name, item["msg"]))
    return Error(code=ErrorCode.INVALID_REQUEST, errors=messages)
# TODO: loosen restrictions here during dev. We should be specific
# about only parsing things that are of a "Request" type, but there are
# a handful of types that need work in order to enforce that.
#
# These can be easily found by swapping the following comment and running
# mypy.
#
# A = TypeVar("A", bound="BaseRequest")
# Bound TypeVar used by parse_request/parse_uri below.
A = TypeVar("A", bound="BaseModel")
def parse_request(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
    """Deserialize the JSON request body into `cls`, enforcing group access."""
    denied = check_access(req)
    if isinstance(denied, Error):
        return denied

    try:
        return cls.parse_obj(req.get_json())
    except ValidationError as err:
        return convert_error(err)
def parse_uri(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
    """Deserialize query-string parameters into `cls`, enforcing group access."""
    denied = check_access(req)
    if isinstance(denied, Error):
        return denied

    params = {key: req.params[key] for key in req.params}

    try:
        return cls.parse_obj(params)
    except ValidationError as err:
        return convert_error(err)
class RequestException(Exception):
    """Exception wrapper carrying a structured onefuzz Error."""

    def __init__(self, error: Error):
        # keep the structured error available to handlers
        self.error = error
        super().__init__("error %s" % error)

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Optional, Tuple
from uuid import UUID
from onefuzztypes.models import TaskEvent as BASE_TASK_EVENT
from onefuzztypes.models import (
TaskEventSummary,
WorkerDoneEvent,
WorkerEvent,
WorkerRunningEvent,
)
from .orm import ORMMixin
class TaskEvent(BASE_TASK_EVENT, ORMMixin):
    """Persisted worker events for a task, with summary helpers."""

    @classmethod
    def get_summary(cls, task_id: UUID) -> List[TaskEventSummary]:
        """Return this task's events, oldest first, as display summaries."""
        events = sorted(
            cls.search(query={"task_id": [task_id]}), key=lambda e: e.Timestamp
        )
        summaries = []
        for event in events:
            summaries.append(
                TaskEventSummary(
                    timestamp=event.Timestamp,
                    event_data=cls.get_event_data(event.event_data),
                    event_type=type(event.event_data.event).__name__,
                )
            )
        return summaries

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # partition key only: events grouped by task_id
        return ("task_id", None)

    @classmethod
    def get_event_data(cls, worker_event: WorkerEvent) -> str:
        """Render a short human-readable description of a worker event."""
        event = worker_event.event
        if isinstance(event, WorkerDoneEvent):
            return "exit status: %s" % event.exit_status
        if isinstance(event, WorkerRunningEvent):
            return ""
        return "Unrecognized event: %s" % event

View File

@ -0,0 +1,320 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Dict, List, Optional
from uuid import UUID
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig
from ..azure.containers import blob_exists, get_container_sas_url, get_containers
from ..azure.creds import get_fuzz_storage, get_instance_name
from ..azure.queue import get_queue_sas
from .defs import TASK_DEFINITIONS
LOGGER = logging.getLogger("onefuzz")
def get_input_container_queues(config: TaskConfig) -> Optional[List[str]]:  # tasks.Task
    """Names of containers whose queues this task type monitors, if any."""
    if config.task.type not in TASK_DEFINITIONS:
        raise TaskConfigError("unsupported task type: %s" % config.task.type.name)

    monitored = TASK_DEFINITIONS[config.task.type].monitor_queue
    if monitored is None:
        return None
    return [entry.name for entry in config.containers if entry.type == monitored]
def check_val(compare: Compare, expected: int, actual: int) -> bool:
    """Evaluate the `expected <compare> actual` count constraint."""
    checks = {
        Compare.Equal: lambda e, a: e == a,
        Compare.AtLeast: lambda e, a: e <= a,
        Compare.AtMost: lambda e, a: e >= a,
    }
    if compare not in checks:
        raise NotImplementedError
    return checks[compare](expected, actual)
def check_container(
    compare: Compare,
    expected: int,
    container_type: ContainerType,
    containers: Dict[ContainerType, List[str]],
) -> None:
    """Raise TaskConfigError unless the count of `container_type` entries
    satisfies the (compare, expected) constraint."""
    actual = len(containers.get(container_type, []))
    if check_val(compare, expected, actual):
        return
    raise TaskConfigError(
        "container type %s: expected %s %d, got %d"
        % (container_type.name, compare.name, expected, actual)
    )
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
    """Validate the containers in a task config against its task definition.

    Raises TaskConfigError when a referenced container does not exist, a
    count constraint fails, a container type is unsupported by the task,
    or the monitor queue names an unused container type.
    """
    all_containers = set(get_containers().keys())

    containers: Dict[ContainerType, List[str]] = {}
    for container in config.containers:
        if container.name not in all_containers:
            raise TaskConfigError("missing container: %s" % container.name)
        containers.setdefault(container.type, []).append(container.name)

    for container_def in definition.containers:
        check_container(
            container_def.compare, container_def.value, container_def.type, containers
        )

    defined_types = [x.type for x in definition.containers]
    for container_type in containers:
        if container_type not in defined_types:
            # BUGFIX: the original passed container_type.name as a second
            # TaskConfigError argument instead of %-formatting the message
            raise TaskConfigError(
                "unsupported container type for this task: %s" % container_type.name
            )

    if definition.monitor_queue:
        if definition.monitor_queue not in defined_types:
            raise TaskConfigError(
                "unable to monitor container type as it is not used by this task: %s"
                % definition.monitor_queue.name
            )
def check_config(config: TaskConfig) -> None:
    """Validate a TaskConfig before task creation.

    Checks the task type, container constraints, vm/pool counts, and
    feature-specific requirements.  Raises TaskConfigError on failure;
    a missing target_exe blob is only a warning (it may be uploaded later).
    """
    if config.task.type not in TASK_DEFINITIONS:
        raise TaskConfigError("unsupported task type: %s" % config.task.type.name)

    if config.vm is not None and config.pool is not None:
        raise TaskConfigError("either the vm or pool must be specified, but not both")

    definition = TASK_DEFINITIONS[config.task.type]

    check_containers(definition, config)

    if (
        TaskFeature.supervisor_exe in definition.features
        and not config.task.supervisor_exe
    ):
        err = "missing supervisor_exe"
        LOGGER.error(err)
        # CONSISTENCY FIX: raise the logged message instead of duplicating
        # the string literal
        raise TaskConfigError(err)

    if config.vm:
        if not check_val(definition.vm.compare, definition.vm.value, config.vm.count):
            err = "invalid vm count: expected %s %d, got %s" % (
                definition.vm.compare,
                definition.vm.value,
                config.vm.count,
            )
            LOGGER.error(err)
            raise TaskConfigError(err)
    elif config.pool:
        if not check_val(definition.vm.compare, definition.vm.value, config.pool.count):
            err = "invalid vm count: expected %s %d, got %s" % (
                definition.vm.compare,
                definition.vm.value,
                config.pool.count,
            )
            LOGGER.error(err)
            raise TaskConfigError(err)
    else:
        raise TaskConfigError("either the vm or pool must be specified")

    if TaskFeature.target_exe in definition.features:
        container = [x for x in config.containers if x.type == ContainerType.setup][0]
        if not blob_exists(container.name, config.task.target_exe):
            err = "target_exe `%s` does not exist in the setup container `%s`" % (
                config.task.target_exe,
                container.name,
            )
            # warning only: the target may be uploaded after the task is created
            LOGGER.warning(err)

    if TaskFeature.generator_exe in definition.features:
        container = [x for x in config.containers if x.type == ContainerType.tools][0]

        if not config.task.generator_exe:
            raise TaskConfigError("generator_exe is not defined")

        tools_paths = ["{tools_dir}/", "{tools_dir}\\"]
        for tool_path in tools_paths:
            if config.task.generator_exe.startswith(tool_path):
                generator = config.task.generator_exe.replace(tool_path, "")
                if not blob_exists(container.name, generator):
                    err = (
                        "generator_exe `%s` does not exist in the tools container `%s`"
                        % (
                            config.task.generator_exe,
                            container.name,
                        )
                    )
                    LOGGER.error(err)
                    raise TaskConfigError(err)

    if TaskFeature.stats_file in definition.features:
        if config.task.stats_file is not None and config.task.stats_format is None:
            err = "using a stats_file requires a stats_format"
            LOGGER.error(err)
            raise TaskConfigError(err)
def build_task_config(
    job_id: UUID, task_id: UUID, task_config: TaskConfig
) -> TaskUnitConfig:
    """Translate a user-facing TaskConfig into the TaskUnitConfig consumed
    by the fuzzing agent: SAS-signed queue/container URLs plus the subset
    of task features enabled by this task type's definition.
    """
    if task_config.task.type not in TASK_DEFINITIONS:
        raise TaskConfigError("unsupported task type: %s" % task_config.task.type.name)
    definition = TASK_DEFINITIONS[task_config.task.type]
    config = TaskUnitConfig(
        job_id=job_id,
        task_id=task_id,
        task_type=task_config.task.type,
        instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
        telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
        heartbeat_queue=get_queue_sas(
            "heartbeat",
            account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
            add=True,
        ),
        back_channel_address="https://%s.azurewebsites.net/api/back_channel"
        % (get_instance_name()),
    )
    if definition.monitor_queue:
        # per-task input queue for the monitored container's events
        config.input_queue = get_queue_sas(
            task_id,
            add=True,
            read=True,
            update=True,
            process=True,
            account_id=get_fuzz_storage(),
        )
    for container_def in definition.containers:
        # the setup container is handled separately by the agent
        if container_def.type == ContainerType.setup:
            continue
        containers = []
        for (i, container) in enumerate(task_config.containers):
            if container.type != container_def.type:
                continue
            containers.append(
                {
                    "path": "_".join(["task", container_def.type.name, str(i)]),
                    "url": get_container_sas_url(
                        container.name,
                        read=ContainerPermission.Read in container_def.permissions,
                        write=ContainerPermission.Write in container_def.permissions,
                        add=ContainerPermission.Add in container_def.permissions,
                        delete=ContainerPermission.Delete in container_def.permissions,
                        create=ContainerPermission.Create in container_def.permissions,
                        list=ContainerPermission.List in container_def.permissions,
                    ),
                }
            )
        if not containers:
            continue
        # a (Equal|AtMost, 1) constraint means the field is scalar, not a list
        if (
            container_def.compare in [Compare.Equal, Compare.AtMost]
            and container_def.value == 1
        ):
            setattr(config, container_def.type.name, containers[0])
        else:
            setattr(config, container_def.type.name, containers)
    EMPTY_DICT: Dict[str, str] = {}
    EMPTY_LIST: List[str] = []
    # copy over only the feature fields this task type declares
    if TaskFeature.supervisor_exe in definition.features:
        config.supervisor_exe = task_config.task.supervisor_exe
    if TaskFeature.supervisor_env in definition.features:
        config.supervisor_env = task_config.task.supervisor_env or EMPTY_DICT
    if TaskFeature.supervisor_options in definition.features:
        config.supervisor_options = task_config.task.supervisor_options or EMPTY_LIST
    if TaskFeature.supervisor_input_marker in definition.features:
        config.supervisor_input_marker = task_config.task.supervisor_input_marker
    if TaskFeature.target_exe in definition.features:
        config.target_exe = "setup/%s" % task_config.task.target_exe
    if TaskFeature.target_env in definition.features:
        config.target_env = task_config.task.target_env or EMPTY_DICT
    if TaskFeature.target_options in definition.features:
        config.target_options = task_config.task.target_options or EMPTY_LIST
    if TaskFeature.target_options_merge in definition.features:
        config.target_options_merge = task_config.task.target_options_merge or False
    if TaskFeature.rename_output in definition.features:
        config.rename_output = task_config.task.rename_output or False
    if TaskFeature.generator_exe in definition.features:
        config.generator_exe = task_config.task.generator_exe
    if TaskFeature.generator_env in definition.features:
        config.generator_env = task_config.task.generator_env or EMPTY_DICT
    if TaskFeature.generator_options in definition.features:
        config.generator_options = task_config.task.generator_options or EMPTY_LIST
    if (
        TaskFeature.wait_for_files in definition.features
        and task_config.task.wait_for_files
    ):
        config.wait_for_files = task_config.task.wait_for_files.name
    if TaskFeature.analyzer_exe in definition.features:
        config.analyzer_exe = task_config.task.analyzer_exe
    if TaskFeature.analyzer_options in definition.features:
        config.analyzer_options = task_config.task.analyzer_options or EMPTY_LIST
    if TaskFeature.analyzer_env in definition.features:
        config.analyzer_env = task_config.task.analyzer_env or EMPTY_DICT
    if TaskFeature.stats_file in definition.features:
        config.stats_file = task_config.task.stats_file
        config.stats_format = task_config.task.stats_format
    if TaskFeature.target_timeout in definition.features:
        config.target_timeout = task_config.task.target_timeout
    if TaskFeature.check_asan_log in definition.features:
        config.check_asan_log = task_config.task.check_asan_log
    if TaskFeature.check_debugger in definition.features:
        config.check_debugger = task_config.task.check_debugger
    if TaskFeature.check_retry_count in definition.features:
        config.check_retry_count = task_config.task.check_retry_count or 0
    return config
def get_setup_container(config: TaskConfig) -> str:
    """Return the name of the task's required setup container."""
    setup = next(
        (entry.name for entry in config.containers if entry.type == ContainerType.setup),
        None,
    )
    if setup is None:
        raise TaskConfigError(
            "task missing setup container: task_type = %s" % config.task.type
        )
    return setup
class TaskConfigError(Exception):
    """Raised when a task configuration fails validation."""

    pass

View File

@ -0,0 +1,376 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from onefuzztypes.enums import (
Compare,
ContainerPermission,
ContainerType,
TaskFeature,
TaskType,
)
from onefuzztypes.models import ContainerDefinition, TaskDefinition, VmDefinition
# all tasks are required to have a 'setup' container
# Schema for each supported task type: the TaskFeatures that apply to it, the
# VM-count constraint, the containers the task requires (with the storage
# permissions granted on each), and, optionally, which container's queue the
# task monitors for new blobs.
TASK_DEFINITIONS = {
    # analysis task: reads crashes/tools, writes results to 'analysis';
    # driven by new blobs in the 'crashes' container queue
    TaskType.generic_analysis: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_options,
            TaskFeature.analyzer_exe,
            TaskFeature.analyzer_env,
            TaskFeature.analyzer_options,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.analysis,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                    ContainerPermission.Create,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
    # libFuzzer fuzzing: writes crashes, reads/writes the shared 'inputs'
    # corpus, with optional read-only seed containers; no monitor queue
    TaskType.libfuzzer_fuzz: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_workers,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Write, ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                    ContainerPermission.Create,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=0,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    # libFuzzer crash reporting: reads crashes, creates (unique) reports or
    # no_repro records; driven by the 'crashes' container queue
    TaskType.libfuzzer_crash_report: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_retry_count,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
    # libFuzzer coverage: single VM; reads inputs, maintains the 'coverage'
    # container; driven by the 'readonly_inputs' container queue
    TaskType.libfuzzer_coverage: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
        ],
        vm=VmDefinition(compare=Compare.Equal, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.coverage,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Create,
                    ContainerPermission.List,
                    ContainerPermission.Read,
                    ContainerPermission.Write,
                ],
            ),
        ],
        monitor_queue=ContainerType.readonly_inputs,
    ),
    # libFuzzer merge: single VM; reads seed containers, writes the merged
    # corpus into 'inputs'; driven by the 'inputs' container queue
    TaskType.libfuzzer_merge: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
        ],
        vm=VmDefinition(compare=Compare.Equal, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create, ContainerPermission.List],
            ),
        ],
        monitor_queue=ContainerType.inputs,
    ),
    # supervisor-driven fuzzing: a supervisor binary from 'tools' wraps the
    # target; creates crashes and reads/creates inputs; no monitor queue
    TaskType.generic_supervisor: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_options,
            TaskFeature.supervisor_exe,
            TaskFeature.supervisor_env,
            TaskFeature.supervisor_options,
            TaskFeature.supervisor_input_marker,
            TaskFeature.wait_for_files,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Create,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
        ],
        monitor_queue=None,
    ),
    # supervisor-driven merge: reads seed containers, creates entries in
    # 'inputs'; reports progress via a stats file; no monitor queue
    TaskType.generic_merge: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_options,
            TaskFeature.supervisor_exe,
            TaskFeature.supervisor_env,
            TaskFeature.supervisor_options,
            TaskFeature.supervisor_input_marker,
            TaskFeature.stats_file,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    # generator-based fuzzing: a generator from 'tools' mutates read-only
    # inputs that are fed to the target; creates crashes; no monitor queue
    TaskType.generic_generator: TaskDefinition(
        features=[
            TaskFeature.generator_exe,
            TaskFeature.generator_env,
            TaskFeature.generator_options,
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.rename_output,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    # generic crash reporting: reads crashes, creates (unique) reports or
    # no_repro records; driven by the 'crashes' container queue
    TaskType.generic_crash_report: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
}

View File

@ -0,0 +1,231 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm
from ..azure.creds import get_fuzz_storage
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward
class Task(BASE_TASK, ORMMixin):
    """ORM-backed task entity, including its state-machine handlers.

    Fix over the original: ``logging.warn`` (a deprecated alias) is replaced
    with ``logging.warning`` in ``get_pool``.
    """

    def ready_to_schedule(self) -> bool:
        """Return True when every prerequisite task has started.

        If a prerequisite errored, this task is marked as stopping (and
        saved) before returning False.
        """
        if self.config.prereq_tasks:
            for task_id in self.config.prereq_tasks:
                task = Task.get_by_task_id(task_id)
                # if a prereq task fails, then mark this task as failed
                if isinstance(task, Error):
                    self.error = task
                    self.state = TaskState.stopping
                    self.save()
                    return False

                if task.state not in task.state.has_started():
                    return False
        return True

    @classmethod
    def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
        """Create and persist a task, deriving the target OS from the
        configured VM image or from the configured pool."""
        if config.vm:
            os = get_os(config.vm.region, config.vm.image)
        elif config.pool:
            pool = Pool.get_by_name(config.pool.pool_name)
            if isinstance(pool, Error):
                return pool
            os = pool.os
        else:
            raise Exception("task must have vm or pool")
        task = cls(config=config, job_id=job_id, os=os)
        task.save()
        return task

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        # heartbeats are high-churn and are not persisted with the entity
        return {"heartbeats": ...}

    def is_ready(self) -> bool:
        """Return True when every prerequisite task is running.

        A prerequisite in an Error state marks this task as stopping.
        """
        if self.config.prereq_tasks:
            for prereq_id in self.config.prereq_tasks:
                prereq = Task.get_by_task_id(prereq_id)
                if isinstance(prereq, Error):
                    logging.info("task prereq has error: %s - %s", self.task_id, prereq)
                    self.error = prereq
                    self.state = TaskState.stopping
                    self.save()
                    return False
                if prereq.state != TaskState.running:
                    logging.info(
                        "task is waiting on prereq: %s - %s:",
                        self.task_id,
                        prereq.task_id,
                    )
                    return False

        return True

    # At current, the telemetry filter will generate something like this:
    #
    # {
    #     'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
    #     'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
    #     'state': 'stopped',
    #     'config': {
    #         'task': {'type': 'libfuzzer_coverage'},
    #         'vm': {'count': 1}
    #     }
    # }
    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "config": {"vm": {"count": ...}, "task": {"type": ...}},
        }

    def event_include(self) -> Optional[MappingIntStrAny]:
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "error": ...,
        }

    def init(self) -> None:
        """State handler: create the task's queue, then wait for scheduling."""
        create_queue(self.task_id, account_id=get_fuzz_storage())
        self.state = TaskState.waiting
        self.save()

    def stopping(self) -> None:
        """State handler: tear down the task's queue/forwards and stop nodes."""
        # TODO: we need to tell every node currently working on this task to stop
        # TODO: we need to 'unschedule' this task from the existing pools
        self.state = TaskState.stopping
        logging.info("stopping task: %s:%s", self.job_id, self.task_id)
        ProxyForward.remove_forward(self.task_id)
        delete_queue(str(self.task_id), account_id=get_fuzz_storage())
        Node.stop_task(self.task_id)
        self.state = TaskState.stopped
        self.save()

    def queue_stop(self) -> None:
        """Queue the `stopping` handler for asynchronous execution."""
        self.queue(method=self.stopping)

    @classmethod
    def search_states(
        cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
    ) -> List["Task"]:
        """Search tasks, optionally filtered by job id and/or task states."""
        query: QueryFilter = {}
        if job_id:
            query["job_id"] = [job_id]
        if states:
            query["state"] = states

        return cls.search(query=query)

    @classmethod
    def search_expired(cls) -> List["Task"]:
        """Return still-available tasks whose end_time has already passed."""
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(
            query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
        )

    @classmethod
    def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
        """Look up a single task by id; Error when absent or ambiguous."""
        tasks = cls.search(query={"task_id": [task_id]})
        if not tasks:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])

        if len(tasks) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
            )
        task = tasks[0]
        return task

    def get_pool(self) -> Optional[Pool]:
        """Resolve the pool this task should run on, or None if unavailable."""
        if self.config.pool:
            pool = Pool.get_by_name(self.config.pool.pool_name)
            if isinstance(pool, Error):
                logging.info(
                    "unable to schedule task to pool: %s - %s", self.task_id, pool
                )
                return None
            return pool
        elif self.config.vm:
            # legacy path: find any scaleset whose sku & image match the VM config
            scalesets = Scaleset.search()
            scalesets = [
                x
                for x in scalesets
                if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
            ]
            for scaleset in scalesets:
                pool = Pool.get_by_name(scaleset.pool_name)
                if isinstance(pool, Error):
                    logging.info(
                        "unable to schedule task to pool: %s - %s",
                        self.task_id,
                        pool,
                    )
                else:
                    return pool

        # logging.warning: `warn` is a deprecated alias
        logging.warning(
            "unable to find a scaleset that matches the task prereqs: %s",
            self.task_id,
        )
        return None

    def get_repro_vm_config(self) -> Union[TaskVm, None]:
        """Return a TaskVm usable for crash repro, from the task's own VM
        config or from the first scaleset in its pool."""
        if self.config.vm:
            return self.config.vm

        if self.config.pool is None:
            raise Exception("either pool or vm must be specified: %s" % self.task_id)

        pool = Pool.get_by_name(self.config.pool.pool_name)
        if isinstance(pool, Error):
            logging.info("unable to find pool from task: %s", self.task_id)
            return None

        for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
            return TaskVm(
                region=scaleset.region,
                sku=scaleset.vm_sku,
                image=scaleset.image,
            )

        logging.warning(
            "no scalesets are defined for task: %s:%s", self.job_id, self.task_id
        )
        return None

    def on_start(self) -> None:
        # try to keep this effectively idempotent
        if self.end_time is None:
            self.end_time = datetime.utcnow() + timedelta(
                hours=self.config.task.duration
            )
            self.save()

            # imported here to avoid a circular import with ..jobs
            from ..jobs import Job

            job = Job.get(self.job_id)
            if job:
                job.on_start()

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        # (PartitionKey, RowKey) for the backing table
        return ("job_id", "task_id")

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Dict, List
from uuid import UUID
from onefuzztypes.enums import OS, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from ..azure.containers import blob_exists, get_container_sas_url, save_blob
from ..azure.creds import get_func_storage
from .config import build_task_config, get_setup_container
from .main import Task
HOURS = 60 * 60  # number of seconds in an hour
def schedule_tasks() -> None:
    """Schedule every waiting task whose prerequisites have started.

    Tasks are grouped by job, their agent config is serialized once (saved
    to the 'task-configs' blob and embedded in the WorkUnit), and a
    singleton WorkSet is scheduled onto the task's pool `count` times.

    Improvements over the original: the redundant ``task_config =
    agent_config`` alias is removed and the config is serialized to JSON
    once instead of twice; grouping uses ``dict.setdefault``.
    """
    to_schedule: Dict[UUID, List[Task]] = {}

    for task in Task.search_states(states=[TaskState.waiting]):
        if not task.ready_to_schedule():
            continue
        to_schedule.setdefault(task.job_id, []).append(task)

    for tasks in to_schedule.values():
        # TODO: for now, we're only scheduling one task per VM.

        for task in tasks:
            logging.info("scheduling task: %s", task.task_id)
            agent_config = build_task_config(task.job_id, task.task_id, task.config)

            setup_container = get_setup_container(task.config)
            setup_url = get_container_sas_url(setup_container, read=True, list=True)

            # an OS-appropriate setup script is optional
            setup_script = None
            if task.os == OS.windows and blob_exists(setup_container, "setup.ps1"):
                setup_script = "setup.ps1"
            if task.os == OS.linux and blob_exists(setup_container, "setup.sh"):
                setup_script = "setup.sh"

            # serialize once; the same JSON is persisted and handed to the agent
            agent_config_json = agent_config.json()
            save_blob(
                "task-configs",
                "%s/config.json" % task.task_id,
                agent_config_json,
                account_id=get_func_storage(),
            )

            reboot = False
            count = 1
            if task.config.pool:
                count = task.config.pool.count
                reboot = task.config.task.reboot_after_setup is True
            elif task.config.vm:
                # this branch should go away when we stop letting people specify
                # VM configs directly.
                count = task.config.vm.count
                reboot = (
                    task.config.vm.reboot_after_setup is True
                    or task.config.task.reboot_after_setup is True
                )

            work_unit = WorkUnit(
                job_id=agent_config.job_id,
                task_id=agent_config.task_id,
                task_type=agent_config.task_type,
                config=agent_config_json,
            )

            # For now, only offer singleton work sets.
            work_set = WorkSet(
                reboot=reboot,
                script=(setup_script is not None),
                setup_url=setup_url,
                work_units=[work_unit],
            )

            pool = task.get_pool()
            if not pool:
                logging.info("unable to find pool for task: %s", task.task_id)
                continue

            for _ in range(count):
                pool.schedule_workset(work_set)
            task.state = TaskState.scheduled
            task.save()

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, Optional, Union
from onefuzztypes.enums import TelemetryData, TelemetryEvent
from opencensus.ext.azure.log_exporter import AzureLogHandler
# Cached logger for the deployment's own App Insights instance (None until built).
LOCAL_CLIENT: Optional[logging.Logger] = None
# Cached logger for the central OneFuzz telemetry instance (None until built).
CENTRAL_CLIENT: Optional[logging.Logger] = None
def _get_client(environ_key: str) -> Optional[logging.Logger]:
    """Build a logger exporting to the App Insights instance whose
    instrumentation key is stored in the *environ_key* environment variable.

    Returns None when the environment variable is unset.
    """
    key = os.environ.get(environ_key)
    if key is None:
        return None

    # Use a distinct logger per environment key.  The original used the single
    # shared "onefuzz" logger for both the local and the central client, which
    # attached both AzureLogHandlers to one logger — so every local
    # (unfiltered) event was also exported to the central telemetry instance.
    client = logging.getLogger("onefuzz_%s" % environ_key)
    client.addHandler(AzureLogHandler(connection_string="InstrumentationKey=%s" % key))
    return client
def _central_client() -> Optional[logging.Logger]:
    """Return the central OneFuzz telemetry client, creating it lazily."""
    global CENTRAL_CLIENT

    if CENTRAL_CLIENT is None:
        CENTRAL_CLIENT = _get_client("ONEFUZZ_TELEMETRY")

    return CENTRAL_CLIENT
def _local_client() -> Union[None, Any, logging.Logger]:
    """Return the deployment-local App Insights client, creating it lazily."""
    global LOCAL_CLIENT

    if LOCAL_CLIENT is None:
        LOCAL_CLIENT = _get_client("APPINSIGHTS_INSTRUMENTATIONKEY")

    return LOCAL_CLIENT
# NOTE: All telemetry that is *NOT* using the ORM telemetry_include should
# go through this method
#
# This provides a point of inspection to know if it's data that is safe to
# log to the central OneFuzz telemetry point
def track_event(
    event: TelemetryEvent, data: Dict[TelemetryData, Union[str, int]]
) -> None:
    """Record a telemetry event locally and, when shareable, centrally.

    The local client receives all fields; the central client only receives
    events in TelemetryEvent.can_share(), restricted to fields in
    TelemetryData.can_share().
    """
    central = _central_client()
    local = _local_client()

    if local:
        everything = {k.name: v for (k, v) in data.items()}
        local.info(event.name, extra={"custom_dimensions": everything})

    if event in TelemetryEvent.can_share() and central:
        shareable = {
            k.name: v for (k, v) in data.items() if k in TelemetryData.can_share()
        }
        central.info(event.name, extra={"custom_dimensions": shareable})
# NOTE: This should *only* be used for logging Telemetry data that uses
# the ORM telemetry_include method to limit data for telemetry.
def track_event_filtered(event: TelemetryEvent, data: Any) -> None:
    """Record pre-filtered (ORM telemetry_include) data locally and,
    for shareable events, centrally."""
    central = _central_client()
    local = _local_client()

    if local:
        local.info(event.name, extra={"custom_dimensions": data})

    if central and event in TelemetryEvent.can_share():
        central.info(event.name, extra={"custom_dimensions": data})

View File

@ -0,0 +1,110 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Dict, Optional, Type
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType
from pydantic import BaseModel
from .azure.creds import get_func_storage
from .azure.queue import queue_object
# This class isn't intended to be shared outside of the service
class Update(BaseModel):
    """Queued request to update (or invoke a method on) a stored object."""

    # Which ORM-backed object type the update applies to.
    update_type: UpdateType
    # Table-storage key of the target entity (when the update targets one).
    PartitionKey: Optional[str]
    RowKey: Optional[str]
    # Optional method name to invoke on the object instead of its state handler.
    method: Optional[str]
def queue_update(
    update_type: UpdateType,
    PartitionKey: Optional[str] = None,
    RowKey: Optional[str] = None,
    method: Optional[str] = None,
    visibility_timeout: Optional[int] = None,
) -> None:
    """Serialize an Update and push it onto the service's update queue.

    Args:
        update_type: kind of object being updated.
        PartitionKey / RowKey: table-storage key of the target entity.
        method: optional method name to invoke instead of the state handler.
        visibility_timeout: passed through to queue_object (presumably the
            Azure queue message visibility delay, in seconds — verify there).

    Fixes over the original: `visibility_timeout` is annotated
    Optional[int] (PEP 484 disallows implicit Optional), and the CloudError
    is re-raised with a bare `raise` to preserve the original traceback.
    """
    logging.info(
        "queuing type:%s id:[%s,%s] method:%s timeout: %s",
        update_type.name,
        PartitionKey,
        RowKey,
        method,
        visibility_timeout,
    )

    update = Update(
        update_type=update_type, PartitionKey=PartitionKey, RowKey=RowKey, method=method
    )

    try:
        if not queue_object(
            "update-queue",
            update,
            account_id=get_func_storage(),
            visibility_timeout=visibility_timeout,
        ):
            logging.error("unable to queue update")
    except CloudError as err:
        logging.error("GOT ERROR %s", repr(err))
        logging.error("GOT ERROR %s", dir(err))
        raise
def execute_update(update: Update) -> None:
    """Dispatch a dequeued Update to the object it targets.

    If the update names a method present on the object, that method is
    invoked; otherwise the handler named after the object's current state
    (if any) is invoked.

    Raises:
        NotImplementedError: for update types with no registered object type.

    Fix over the original: corrected log-message typo
    ("unable find to obj" -> "unable to find obj").
    """
    # imported here to avoid circular imports at module load time
    from .jobs import Job
    from .orm import ORMMixin
    from .pools import Node, Pool, Scaleset
    from .proxy import Proxy
    from .repro import Repro
    from .tasks.main import Task

    update_objects: Dict[UpdateType, Type[ORMMixin]] = {
        UpdateType.Task: Task,
        UpdateType.Job: Job,
        UpdateType.Repro: Repro,
        UpdateType.Proxy: Proxy,
        UpdateType.Pool: Pool,
        UpdateType.Node: Node,
        UpdateType.Scaleset: Scaleset,
    }

    # TODO: remove these from being queued, these updates are handled elsewhere
    if update.update_type == UpdateType.Scaleset:
        return

    if update.update_type in update_objects:
        if update.PartitionKey is None or update.RowKey is None:
            raise Exception("unsupported update: %s" % update)

        obj = update_objects[update.update_type].get(update.PartitionKey, update.RowKey)
        if not obj:
            logging.error("unable to find obj to update %s", update)
            return

        if update.method and hasattr(obj, update.method):
            getattr(obj, update.method)()
            return
        else:
            # no explicit method: run the state-machine handler named after
            # the object's current state, when such a handler is defined
            state = getattr(obj, "state", None)
            if state is None:
                logging.error("queued update for object without state: %s", update)
                return
            func = getattr(obj, state.name, None)
            if func is None:
                logging.info("no function to implement state: %s", update)
                return
            func()
        return

    raise NotImplementedError("unimplemented update type: %s" % update.update_type.name)
def perform_update(update: Update) -> None:
    """Entry point used by the queue processor: log, then run the update."""
    logging.info("performing queued update: %s", update)
    execute_update(update)

View File

@ -0,0 +1,31 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import Dict
from memoization import cached
from onefuzztypes.responses import Version
from .__version__ import __version__
@cached
def read_local_file(filename: str) -> str:
    """Return the stripped contents of *filename* located next to this
    module, or "UNKNOWN" when the file does not exist."""
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), filename)
    if not os.path.exists(path):
        return "UNKNOWN"

    with open(path, "r") as handle:
        return handle.read().strip()
def versions() -> Dict[str, Version]:
    """Report the service's git revision, build id, and package version."""
    info = Version(
        git=read_local_file("git.version"),
        build=read_local_file("build.id"),
        version=__version__,
    )
    return {"onefuzz": info}