mirror of https://github.com/microsoft/onefuzz.git (synced 2025-06-16 11:58:09 +00:00)
initial public release
src/api-service/__app__/onefuzzlib/__init__.py (new file, 0 lines)
src/api-service/__app__/onefuzzlib/__version__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# pylint: disable=W0612,C0111
__version__ = "0.0.0"
src/api-service/__app__/onefuzzlib/agent_authorization.py (new file, 75 lines)
@@ -0,0 +1,75 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Callable, Union
from uuid import UUID

import azure.functions as func
import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from pydantic import BaseModel

from .pools import Scaleset
from .request import not_ok


class TokenData(BaseModel):
    application_id: UUID
    object_id: UUID


def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
    """ Obtains the Access Token from the Authorization Header """
    auth: str = request.headers.get("Authorization", None)
    if not auth:
        return Error(
            code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
        )
    parts = auth.split()

    if parts[0].lower() != "bearer":
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["Authorization header must start with Bearer"],
        )

    elif len(parts) == 1:
        return Error(code=ErrorCode.INVALID_REQUEST, errors=["Token not found"])

    elif len(parts) > 2:
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["Authorization header must be Bearer token"],
        )

    # This token has already been verified by the azure authentication layer
    token = jwt.decode(parts[1], verify=False)
    return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))


@cached(ttl=60)
def is_authorized(token_data: TokenData) -> bool:
    scalesets = Scaleset.get_by_object_id(token_data.object_id)
    return len(scalesets) > 0


def verify_token(
    req: func.HttpRequest, func: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse:
    token = try_get_token_auth_header(req)

    if isinstance(token, Error):
        return not_ok(token, status_code=401, context="token verification")

    if not is_authorized(token):
        return not_ok(
            Error(code=ErrorCode.UNAUTHORIZED, errors=["Unrecognized agent"]),
            status_code=401,
            context="token verification",
        )

    return func(req)
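For context, a minimal sketch of how verify_token above might guard an agent-facing Azure Functions endpoint. The handler name, import path, and response body are illustrative assumptions and are not part of this commit:

import azure.functions as func

from ..onefuzzlib.agent_authorization import verify_token  # assumed import path


def handle_agent_request(req: func.HttpRequest) -> func.HttpResponse:
    # only reached when the bearer token maps to a registered scaleset
    return func.HttpResponse("ok", status_code=200)


def main(req: func.HttpRequest) -> func.HttpResponse:
    # verify_token returns a 401 response itself when the Authorization
    # header is missing, malformed, or not tied to a known scaleset
    return verify_token(req, handle_agent_request)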
src/api-service/__app__/onefuzzlib/azure/auth.py (new file, 37 lines)
@@ -0,0 +1,37 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import subprocess  # nosec - used for ssh key generation
import tempfile
from typing import Tuple
from uuid import uuid4

from onefuzztypes.models import Authentication


def generate_keypair() -> Tuple[str, str]:

    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "key")

        cmd = ["ssh-keygen", "-t", "rsa", "-f", filename, "-P", "", "-b", "2048"]
        subprocess.check_output(cmd)  # nosec - all arguments are under our control

        with open(filename, "r") as handle:
            private = handle.read()

        with open(filename + ".pub", "r") as handle:
            public = handle.read().strip()

    return (public, private)


def build_auth() -> Authentication:
    public_key, private_key = generate_keypair()
    auth = Authentication(
        password=str(uuid4()), public_key=public_key, private_key=private_key
    )
    return auth
src/api-service/__app__/onefuzzlib/azure/containers.py (new file, 155 lines)
@@ -0,0 +1,155 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import datetime
import os
import urllib.parse
from typing import Any, Dict, Optional, Union, cast

from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions

from .creds import get_blob_service


def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]:
    return {
        x.name: x.metadata
        for x in get_blob_service(account_id).list_containers(include_metadata=True)
        if not x.name.startswith("$")
    }


def get_container_metadata(
    name: str, account_id: Optional[str] = None
) -> Optional[Dict[str, str]]:
    try:
        result = get_blob_service(account_id).get_container_metadata(name)
        if result:
            return cast(Dict[str, str], result)
    except AzureHttpError:
        pass
    return None


def create_container(
    name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None
) -> Optional[str]:
    try:
        get_blob_service(account_id).create_container(name, metadata=metadata)
    except AzureHttpError:
        # azure storage already logs errors
        return None

    return get_container_sas_url(
        name, read=True, add=True, create=True, write=True, delete=True, list=True
    )


def delete_container(name: str, account_id: Optional[str] = None) -> bool:
    try:
        return bool(get_blob_service(account_id).delete_container(name))
    except AzureHttpError:
        # azure storage already logs errors
        return False


def get_container_sas_url(
    container: str,
    account_id: Optional[str] = None,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
) -> str:
    service = get_blob_service(account_id)
    expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
    permission = ContainerPermissions(read, add, create, write, delete, list)

    sas_token = service.generate_container_shared_access_signature(
        container, permission=permission, expiry=expiry
    )

    url = service.make_container_url(container, sas_token=sas_token)
    url = url.replace("?restype=container&", "?")
    return str(url)


def get_file_sas_url(
    container: str,
    name: str,
    account_id: Optional[str] = None,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
    days: int = 30,
    hours: int = 0,
    minutes: int = 0,
) -> str:
    service = get_blob_service(account_id)
    expiry = datetime.datetime.utcnow() + datetime.timedelta(
        days=days, hours=hours, minutes=minutes
    )
    permission = BlobPermissions(read, add, create, write, delete, list)

    sas_token = service.generate_blob_shared_access_signature(
        container, name, permission=permission, expiry=expiry
    )

    url = service.make_blob_url(container, name, sas_token=sas_token)
    return str(url)


def save_blob(
    container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
) -> None:
    service = get_blob_service(account_id)
    service.create_container(container)
    if isinstance(data, str):
        service.create_blob_from_text(container, name, data)
    elif isinstance(data, bytes):
        service.create_blob_from_bytes(container, name, data)


def get_blob(
    container: str, name: str, account_id: Optional[str] = None
) -> Optional[Any]:  # should be bytes
    service = get_blob_service(account_id)
    try:
        blob = service.get_blob_to_bytes(container, name).content
        return blob
    except AzureMissingResourceHttpError:
        return None


def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
    service = get_blob_service(account_id)
    try:
        service.get_blob_properties(container, name)
        return True
    except AzureMissingResourceHttpError:
        return False


def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
    service = get_blob_service(account_id)
    try:
        service.delete_blob(container, name)
        return True
    except AzureMissingResourceHttpError:
        return False


def auth_download_url(container: str, filename: str) -> str:
    instance = os.environ["ONEFUZZ_INSTANCE"]
    return "%s/api/download?%s" % (
        instance,
        urllib.parse.urlencode({"container": container, "filename": filename}),
    )
src/api-service/__app__/onefuzzlib/azure/creds.py (new file, 116 lines)
@@ -0,0 +1,116 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, List, Optional, Tuple

from azure.cli.core import CLIError
from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.subscription import SubscriptionClient
from azure.storage.blob import BlockBlobService
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id

from .monkeypatch import allow_more_workers


@cached(ttl=60)
def get_msi() -> MSIAuthentication:
    return MSIAuthentication()


@cached(ttl=60)
def mgmt_client_factory(client_class: Any) -> Any:
    allow_more_workers()
    try:
        return get_client_from_cli_profile(client_class)
    except CLIError:
        if issubclass(client_class, SubscriptionClient):
            return client_class(get_msi())
        else:
            return client_class(get_msi(), get_subscription())


@cached(ttl=60)
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]:
    db_client = mgmt_client_factory(StorageManagementClient)
    if account_id is None:
        account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
    resource = parse_resource_id(account_id)
    key = (
        db_client.storage_accounts.list_keys(
            resource["resource_group"], resource["name"]
        )
        .keys[0]
        .value
    )
    return resource["name"], key


@cached(ttl=60)
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
    logging.info("getting blob container (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    service = BlockBlobService(account_name=name, account_key=key)
    return service


@cached
def get_base_resource_group() -> Any:  # should be str
    return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]


@cached
def get_base_region() -> Any:  # should be str
    client = mgmt_client_factory(ResourceManagementClient)
    group = client.resource_groups.get(get_base_resource_group())
    return group.location


@cached
def get_subscription() -> Any:  # should be str
    return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]


@cached
def get_fuzz_storage() -> str:
    return os.environ["ONEFUZZ_DATA_STORAGE"]


@cached
def get_func_storage() -> str:
    return os.environ["ONEFUZZ_FUNC_STORAGE"]


@cached
def get_instance_name() -> str:
    return os.environ["ONEFUZZ_INSTANCE_NAME"]


@cached(ttl=60)
def get_regions() -> List[str]:
    client = mgmt_client_factory(SubscriptionClient)
    subscription = get_subscription()
    locations = client.subscriptions.list_locations(subscription)
    return sorted([x.name for x in locations])


def get_graph_client() -> Any:
    return mgmt_client_factory(GraphRbacManagementClient)


def is_member_of(group_id: str, member_id: str) -> bool:
    client = get_graph_client()
    return bool(
        client.groups.is_member_of(
            CheckGroupMembershipParameters(group_id=group_id, member_id=member_id)
        ).value
    )
src/api-service/__app__/onefuzzlib/azure/disk.py (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Any

from azure.mgmt.compute import ComputeManagementClient
from msrestazure.azure_exceptions import CloudError

from .creds import mgmt_client_factory


def list_disks(resource_group: str) -> Any:
    logging.info("listing disks %s", resource_group)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    return compute_client.disks.list_by_resource_group(resource_group)


def delete_disk(resource_group: str, name: str) -> bool:
    logging.info("deleting disks %s : %s", resource_group, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        compute_client.disks.delete(resource_group, name)
        return True
    except CloudError as err:
        logging.error("unable to delete disk: %s", err)
        return False
src/api-service/__app__/onefuzzlib/azure/image.py (new file, 42 lines)
@@ -0,0 +1,42 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Union

from azure.mgmt.compute import ComputeManagementClient
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region

from .creds import mgmt_client_factory


@cached(ttl=60)
def get_os(region: Region, image: str) -> Union[Error, OS]:
    client = mgmt_client_factory(ComputeManagementClient)
    parsed = parse_resource_id(image)
    if "resource_group" in parsed:
        try:
            name = client.images.get(
                parsed["resource_group"], parsed["name"]
            ).storage_profile.os_disk.os_type.name
        except CloudError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    else:
        publisher, offer, sku, version = image.split(":")
        try:
            if version == "latest":
                version = client.virtual_machine_images.list(
                    region, publisher, offer, sku, top=1
                )[0].name
            name = client.virtual_machine_images.get(
                region, publisher, offer, sku, version
            ).os_disk_image.operating_system.name
        except CloudError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    return OS[name]
src/api-service/__app__/onefuzzlib/azure/ip.py (new file, 150 lines)
@@ -0,0 +1,150 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, Dict, Optional, Union
from uuid import UUID

from azure.mgmt.network import NetworkManagementClient
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error

from .creds import get_base_resource_group, mgmt_client_factory
from .subnet import create_virtual_network, get_subnet_id
from .vmss import get_instance_id


def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
    instance = get_instance_id(scaleset, machine_id)
    if not isinstance(instance, str):
        return None

    resource_group = get_base_resource_group()

    client = mgmt_client_factory(NetworkManagementClient)
    intf = client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
        resource_group, str(scaleset)
    )
    try:
        for interface in intf:
            resource = parse_resource_id(interface.virtual_machine.id)
            if resource.get("resource_name") != instance:
                continue

            for config in interface.ip_configurations:
                if config.private_ip_address is None:
                    continue
                return str(config.private_ip_address)
    except CloudError:
        # this can fail if an interface is removed during the iteration
        pass

    return None


def get_ip(resource_group: str, name: str) -> Optional[Any]:
    logging.info("getting ip %s:%s", resource_group, name)
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        return network_client.public_ip_addresses.get(resource_group, name)
    except CloudError:
        return None


def delete_ip(resource_group: str, name: str) -> Any:
    logging.info("deleting ip %s:%s", resource_group, name)
    network_client = mgmt_client_factory(NetworkManagementClient)
    return network_client.public_ip_addresses.delete(resource_group, name)


def create_ip(resource_group: str, name: str, location: str) -> Any:
    logging.info("creating ip for %s:%s in %s", resource_group, name, location)

    network_client = mgmt_client_factory(NetworkManagementClient)
    params: Dict[str, Union[str, Dict[str, str]]] = {
        "location": location,
        "public_ip_allocation_method": "Dynamic",
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    return network_client.public_ip_addresses.create_or_update(
        resource_group, name, params
    )


def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
    logging.info("getting nic: %s %s", resource_group, name)
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        return network_client.network_interfaces.get(resource_group, name)
    except CloudError:
        return None


def delete_nic(resource_group: str, name: str) -> Optional[Any]:
    logging.info("deleting nic %s:%s", resource_group, name)
    network_client = mgmt_client_factory(NetworkManagementClient)
    return network_client.network_interfaces.delete(resource_group, name)


def create_public_nic(resource_group: str, name: str, location: str) -> Optional[Error]:
    logging.info("creating nic for %s:%s in %s", resource_group, name, location)

    network_client = mgmt_client_factory(NetworkManagementClient)
    subnet_id = get_subnet_id(resource_group, location)
    if not subnet_id:
        return create_virtual_network(resource_group, location, location)

    ip = get_ip(resource_group, name)
    if not ip:
        create_ip(resource_group, name, location)
        return None

    params = {
        "location": location,
        "ip_configurations": [
            {
                "name": "myIPConfig",
                "public_ip_address": ip,
                "subnet": {"id": subnet_id},
            }
        ],
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}

    try:
        network_client.network_interfaces.create_or_update(resource_group, name, params)
    except CloudError as err:
        if "RetryableError" not in repr(err):
            return Error(
                code=ErrorCode.VM_CREATE_FAILED,
                errors=["unable to create nic: %s" % err],
            )
    return None


def get_public_ip(resource_id: str) -> Optional[str]:
    logging.info("getting ip for %s", resource_id)
    network_client = mgmt_client_factory(NetworkManagementClient)
    resource = parse_resource_id(resource_id)
    ip = (
        network_client.network_interfaces.get(
            resource["resource_group"], resource["name"]
        )
        .ip_configurations[0]
        .public_ip_address
    )
    resource = parse_resource_id(ip.id)
    ip = network_client.public_ip_addresses.get(
        resource["resource_group"], resource["name"]
    ).ip_address
    if ip is None:
        return None
    else:
        return str(ip)
src/api-service/__app__/onefuzzlib/azure/monitor.py (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
from typing import Any, Dict

from azure.mgmt.loganalytics import LogAnalyticsManagementClient
from memoization import cached

from .creds import get_base_resource_group, mgmt_client_factory


@cached(ttl=60)
def get_montior_client() -> Any:
    return mgmt_client_factory(LogAnalyticsManagementClient)


@cached(ttl=60)
def get_monitor_settings() -> Dict[str, str]:
    resource_group = get_base_resource_group()
    workspace_name = os.environ["ONEFUZZ_MONITOR"]
    client = get_montior_client()
    customer_id = client.workspaces.get(resource_group, workspace_name).customer_id
    shared_key = client.shared_keys.get_shared_keys(
        resource_group, workspace_name
    ).primary_shared_key
    return {"id": customer_id, "key": shared_key}
src/api-service/__app__/onefuzzlib/azure/monkeypatch.py (new file, 24 lines)
@@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import inspect
import logging

WORKERS_DONE = False


def allow_more_workers() -> None:
    global WORKERS_DONE
    if WORKERS_DONE:
        return

    stack = inspect.stack()
    for entry in stack:
        if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
            if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
                logging.info("bumped thread worker count to 50")
                entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50

    WORKERS_DONE = True
src/api-service/__app__/onefuzzlib/azure/network.py (new file, 46 lines)
@@ -0,0 +1,46 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Optional, Union

from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region

from .creds import get_base_resource_group
from .subnet import create_virtual_network, delete_subnet, get_subnet_id


class Network:
    def __init__(self, region: Region):
        self.group = get_base_resource_group()
        self.region = region

    def exists(self) -> bool:
        return self.get_id() is not None

    def get_id(self) -> Optional[str]:
        return get_subnet_id(self.group, self.region)

    def create(self) -> Union[None, Error]:
        if not self.exists():
            result = create_virtual_network(self.group, self.region, self.region)
            if isinstance(result, CloudError):
                error = Error(
                    code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[result.message]
                )
                logging.error(
                    "network creation failed: %s- %s",
                    self.region,
                    error,
                )
                return error

        return None

    def delete(self) -> None:
        delete_subnet(self.group, self.region)
src/api-service/__app__/onefuzzlib/azure/queue.py (new file, 153 lines)
@@ -0,0 +1,153 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import base64
import datetime
import json
import logging
from typing import List, Optional, Type, TypeVar, Union
from uuid import UUID

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.queue import (
    QueueSasPermissions,
    QueueServiceClient,
    generate_queue_sas,
)
from memoization import cached
from pydantic import BaseModel

from .creds import get_storage_account_name_key

QueueNameType = Union[str, UUID]


@cached(ttl=60)
def get_queue_client(account_id: str) -> QueueServiceClient:
    logging.info("getting blob container (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    account_url = "https://%s.queue.core.windows.net" % name
    client = QueueServiceClient(
        account_url=account_url,
        credential={"account_name": name, "account_key": key},
    )
    return client


def get_queue_sas(
    queue: QueueNameType,
    *,
    account_id: str,
    read: bool = False,
    add: bool = False,
    update: bool = False,
    process: bool = False,
) -> str:
    logging.info("getting queue sas %s (account_id: %s)", queue, account_id)
    name, key = get_storage_account_name_key(account_id)
    expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)

    token = generate_queue_sas(
        name,
        str(queue),
        key,
        permission=QueueSasPermissions(
            read=read, add=add, update=update, process=process
        ),
        expiry=expiry,
    )

    url = "https://%s.queue.core.windows.net/%s?%s" % (name, queue, token)
    return url


@cached(ttl=60)
def create_queue(name: QueueNameType, *, account_id: str) -> None:
    client = get_queue_client(account_id)
    try:
        client.create_queue(str(name))
    except ResourceExistsError:
        pass


def delete_queue(name: QueueNameType, *, account_id: str) -> None:
    client = get_queue_client(account_id)
    queues = client.list_queues()
    if str(name) in [x["name"] for x in queues]:
        client.delete_queue(name)


def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]:
    client = get_queue_client(account_id)
    try:
        return client.get_queue_client(str(name))
    except ResourceNotFoundError:
        return None


def send_message(
    name: QueueNameType,
    message: bytes,
    *,
    account_id: str,
) -> None:
    queue = get_queue(name, account_id=account_id)
    if queue:
        try:
            queue.send_message(base64.b64encode(message).decode())
        except ResourceNotFoundError:
            pass


A = TypeVar("A", bound=BaseModel)


MIN_PEEK_SIZE = 1
MAX_PEEK_SIZE = 32


# Peek at a max of 32 messages
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue(
    name: QueueNameType,
    *,
    account_id: str,
    object_type: Type[A],
    max_messages: int = MAX_PEEK_SIZE,
) -> List[A]:
    result: List[A] = []

    # message count
    if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
        raise ValueError("invalid max messages: %s" % max_messages)

    queue = get_queue(name, account_id=account_id)
    if not queue:
        return result

    for message in queue.peek_messages(max_messages=max_messages):
        decoded = base64.b64decode(message.content)
        raw = json.loads(decoded)
        result.append(object_type.parse_obj(raw))
    return result


def queue_object(
    name: QueueNameType,
    message: BaseModel,
    *,
    account_id: str,
    visibility_timeout: Optional[int] = None,
) -> bool:
    queue = get_queue(name, account_id=account_id)
    if not queue:
        raise Exception("unable to queue object, no such queue: %s" % queue)

    encoded = base64.b64encode(message.json(exclude_none=True).encode()).decode()
    try:
        queue.send_message(encoded, visibility_timeout=visibility_timeout)
        return True
    except ResourceNotFoundError:
        return False
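As an illustration, a small sketch of how queue_object and peek_queue above might be used together, relying on the helpers defined in queue.py and creds.py in this commit. The TaskEvent model and the "task-events" queue name are assumptions for the example only:

from uuid import UUID, uuid4

from pydantic import BaseModel


class TaskEvent(BaseModel):
    task_id: UUID
    state: str


def example() -> None:
    account_id = get_fuzz_storage()  # storage account resource id from .creds
    create_queue("task-events", account_id=account_id)

    # queue_object base64-encodes the JSON form of the model before sending
    queue_object(
        "task-events",
        TaskEvent(task_id=uuid4(), state="scheduled"),
        account_id=account_id,
    )

    # peek without dequeueing; messages are parsed back into TaskEvent objects
    for event in peek_queue(
        "task-events", account_id=account_id, object_type=TaskEvent, max_messages=5
    ):
        print(event.task_id, event.state)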
src/api-service/__app__/onefuzzlib/azure/subnet.py (new file, 70 lines)
@@ -0,0 +1,70 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, Optional, Union, cast

from azure.mgmt.network import NetworkManagementClient
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error

from .creds import mgmt_client_factory


def get_subnet_id(resource_group: str, name: str) -> Optional[str]:
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        subnet = network_client.subnets.get(resource_group, name, name)
        return cast(str, subnet.id)
    except CloudError:
        logging.info(
            "subnet missing: resource group: %s name: %s",
            resource_group,
            name,
        )
    return None


def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
    network_client = mgmt_client_factory(NetworkManagementClient)
    try:
        return network_client.virtual_networks.delete(resource_group, name)
    except CloudError as err:
        if err.error and "InUseSubnetCannotBeDeleted" in str(err.error):
            logging.error(
                "subnet delete failed: %s %s : %s", resource_group, name, repr(err)
            )
            return None
        else:
            raise err


def create_virtual_network(
    resource_group: str, name: str, location: str
) -> Optional[Error]:
    logging.info(
        "creating subnet - resource group: %s name: %s location: %s",
        resource_group,
        name,
        location,
    )

    network_client = mgmt_client_factory(NetworkManagementClient)
    params = {
        "location": location,
        "address_space": {"address_prefixes": ["10.0.0.0/8"]},
        "subnets": [{"name": name, "address_prefix": "10.0.0.0/16"}],
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}

    try:
        network_client.virtual_networks.create_or_update(resource_group, name, params)
    except CloudError as err:
        return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err.message)])

    return None
src/api-service/__app__/onefuzzlib/azure/table.py (new file, 30 lines)
@@ -0,0 +1,30 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Optional

from azure.cosmosdb.table import TableService
from memoization import cached

from .creds import get_storage_account_name_key


@cached(ttl=60)
def get_client(
    table: Optional[str] = None, account_id: Optional[str] = None
) -> TableService:
    if account_id is None:
        account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]

    logging.info("getting table account: (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    client = TableService(account_name=name, account_key=key)

    if table and not client.exists(table):
        logging.info("creating missing table %s", table)
        client.create_table(table, fail_on_exist=False)
    return client
src/api-service/__app__/onefuzzlib/azure/vm.py (new file, 273 lines)
@@ -0,0 +1,273 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID

from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import VirtualMachine
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Authentication, Error
from onefuzztypes.primitives import Extension, Region
from pydantic import BaseModel

from .creds import get_base_resource_group, mgmt_client_factory
from .disk import delete_disk, list_disks
from .image import get_os
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic


def get_vm(name: str) -> Optional[VirtualMachine]:
    resource_group = get_base_resource_group()

    logging.debug("getting vm: %s %s - %s", resource_group, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return cast(
            VirtualMachine,
            compute_client.virtual_machines.get(
                resource_group, name, expand="instanceView"
            ),
        )
    except CloudError as err:
        logging.debug("vm does not exist %s", err)
        return None


def create_vm(
    name: str,
    location: str,
    vm_sku: str,
    image: str,
    password: str,
    ssh_public_key: str,
) -> Union[None, Error]:
    resource_group = get_base_resource_group()
    logging.info("creating vm %s:%s:%s", resource_group, location, name)

    compute_client = mgmt_client_factory(ComputeManagementClient)

    nic = get_public_nic(resource_group, name)
    if nic is None:
        result = create_public_nic(resource_group, name, location)
        if isinstance(result, Error):
            return result
        logging.info("waiting on nic creation")
        return None

    if image.startswith("/"):
        image_ref = {"id": image}
    else:
        image_val = image.split(":", 4)
        image_ref = {
            "publisher": image_val[0],
            "offer": image_val[1],
            "sku": image_val[2],
            "version": image_val[3],
        }

    params: Dict = {
        "location": location,
        "os_profile": {
            "computer_name": "node",
            "admin_username": "onefuzz",
            "admin_password": password,
        },
        "hardware_profile": {"vm_size": vm_sku},
        "storage_profile": {"image_reference": image_ref},
        "network_profile": {"network_interfaces": [{"id": nic.id}]},
    }

    image_os = get_os(location, image)
    if isinstance(image_os, Error):
        return image_os

    if image_os == OS.linux:
        params["os_profile"]["linux_configuration"] = {
            "disable_password_authentication": True,
            "ssh": {
                "public_keys": [
                    {
                        "path": "/home/onefuzz/.ssh/authorized_keys",
                        "key_data": ssh_public_key,
                    }
                ]
            },
        }

    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}

    try:
        compute_client.virtual_machines.create_or_update(resource_group, name, params)
    except CloudError as err:
        if "The request failed due to conflict with a concurrent request" in str(err):
            logging.debug(
                "create VM had conflicts with concurrent request, ignoring %s", err
            )
            return None
        return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
    return None


def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
    resource_group = get_base_resource_group()

    logging.debug(
        "getting extension: %s:%s:%s - %s",
        resource_group,
        vm_name,
        extension_name,
    )
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return compute_client.virtual_machine_extensions.get(
            resource_group, vm_name, extension_name
        )
    except CloudError as err:
        logging.error("extension does not exist %s", err)
        return None


def create_extension(vm_name: str, extension: Dict) -> Any:
    resource_group = get_base_resource_group()

    logging.info(
        "creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
    )
    compute_client = mgmt_client_factory(ComputeManagementClient)
    return compute_client.virtual_machine_extensions.create_or_update(
        resource_group, vm_name, extension["name"], extension
    )


def delete_vm(name: str) -> Any:
    resource_group = get_base_resource_group()

    logging.info("deleting vm: %s %s", resource_group, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    return compute_client.virtual_machines.delete(resource_group, name)


def has_components(name: str) -> bool:
    # check if any of the components associated with a VM still exist.
    #
    # Azure VM Deletion requires we first delete the VM, then delete all of it's
    # resources. This is required to ensure we've cleaned it all up before
    # marking it "done"
    resource_group = get_base_resource_group()
    if get_vm(name):
        return True
    if get_public_nic(resource_group, name):
        return True
    if get_ip(resource_group, name):
        return True

    disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
    if disks:
        return True

    return False


def delete_vm_components(name: str) -> bool:
    resource_group = get_base_resource_group()
    logging.info("deleting vm components %s:%s", resource_group, name)
    if get_vm(name):
        logging.info("deleting vm %s:%s", resource_group, name)
        delete_vm(name)
        return False

    if get_public_nic(resource_group, name):
        logging.info("deleting nic %s:%s", resource_group, name)
        delete_nic(resource_group, name)
        return False

    if get_ip(resource_group, name):
        logging.info("deleting ip %s:%s", resource_group, name)
        delete_ip(resource_group, name)
        return False

    disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
    if disks:
        for disk in disks:
            logging.info("deleting disk %s:%s", resource_group, disk)
            delete_disk(resource_group, disk)
        return False

    return True


class VM(BaseModel):
    name: Union[UUID, str]
    region: Region
    sku: str
    image: str
    auth: Authentication

    def is_deleted(self) -> bool:
        return has_components(str(self.name))

    def exists(self) -> bool:
        return self.get() is not None

    def get(self) -> Optional[VirtualMachine]:
        return get_vm(str(self.name))

    def create(self) -> Union[None, Error]:
        if self.get() is not None:
            return None

        logging.info("vm creating: %s", self.name)
        return create_vm(
            str(self.name),
            self.region,
            self.sku,
            self.image,
            self.auth.password,
            self.auth.public_key,
        )

    def delete(self) -> bool:
        return delete_vm_components(str(self.name))

    def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
        status = []
        to_create = []
        for config in extensions:
            if not isinstance(config["name"], str):
                logging.error("vm agent - incompatable name: %s", repr(config))
                continue
            extension = get_extension(str(self.name), config["name"])

            if extension:
                logging.info(
                    "vm extension state: %s - %s - %s",
                    self.name,
                    config["name"],
                    extension.provisioning_state,
                )
                status.append(extension.provisioning_state)
            else:
                to_create.append(config)

        if to_create:
            for config in to_create:
                create_extension(str(self.name), config)
        else:
            if all([x == "Succeeded" for x in status]):
                return True
            elif "Failed" in status:
                return Error(
                    code=ErrorCode.VM_CREATE_FAILED,
                    errors=["failed to launch extension"],
                )
            elif not ("Creating" in status or "Updating" in status):
                logging.error("vm agent - unknown state %s: %s", self.name, status)

        return False
src/api-service/__app__/onefuzzlib/azure/vmss.py (new file, 334 lines)
@@ -0,0 +1,334 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID

from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import ResourceSku, ResourceSkuRestrictionsType
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region

from .creds import get_base_resource_group, mgmt_client_factory
from .image import get_os


def list_vmss(name: UUID) -> Optional[List[str]]:
    resource_group = get_base_resource_group()
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        instances = [
            x.instance_id
            for x in client.virtual_machine_scale_set_vms.list(
                resource_group, str(name)
            )
        ]
        return instances
    except CloudError as err:
        logging.error("cloud error listing vmss: %s (%s)", name, err)

    return None


def delete_vmss(name: UUID) -> Any:
    resource_group = get_base_resource_group()
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        compute_client.virtual_machine_scale_sets.delete(resource_group, str(name))
    except CloudError as err:
        logging.error("cloud error deleting vmss: %s (%s)", name, err)


def get_vmss(name: UUID) -> Optional[Any]:
    resource_group = get_base_resource_group()
    logging.debug("getting vm: %s", name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return compute_client.virtual_machine_scale_sets.get(resource_group, str(name))
    except CloudError as err:
        logging.debug("vm does not exist %s", err)
        return None


def resize_vmss(name: UUID, capacity: int) -> None:
    check_can_update(name)

    resource_group = get_base_resource_group()
    logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    compute_client.virtual_machine_scale_sets.update(
        resource_group, str(name), {"sku": {"capacity": capacity}}
    )


def get_vmss_size(name: UUID) -> Optional[int]:
    vmss = get_vmss(name)
    if vmss is None:
        return None
    return cast(int, vmss.sku.capacity)


def list_instance_ids(name: UUID) -> Dict[UUID, str]:
    logging.info("get instance IDs for scaleset: %s", name)
    resource_group = get_base_resource_group()
    compute_client = mgmt_client_factory(ComputeManagementClient)

    results = {}
    try:
        for instance in compute_client.virtual_machine_scale_set_vms.list(
            resource_group, str(name)
        ):
            results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
    except CloudError:
        logging.debug("scaleset not available: %s", name)
    return results


@cached(ttl=60)
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
    resource_group = get_base_resource_group()
    logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
    compute_client = mgmt_client_factory(ComputeManagementClient)

    vm_id_str = str(vm_id)
    for instance in compute_client.virtual_machine_scale_set_vms.list(
        resource_group, str(name)
    ):
        if instance.vm_id == vm_id_str:
            return cast(str, instance.instance_id)

    return Error(
        code=ErrorCode.UNABLE_TO_FIND,
        errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
    )


class UnableToUpdate(Exception):
    pass


def check_can_update(name: UUID) -> Any:
    vmss = get_vmss(name)
    if vmss is None:
        raise UnableToUpdate

    if vmss.provisioning_state != "Succeeded":
        raise UnableToUpdate

    return vmss


def reimage_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
    check_can_update(name)

    resource_group = get_base_resource_group()
    logging.info("reimaging scaleset VM - name: %s vm_ids:%s", name, vm_ids)
    compute_client = mgmt_client_factory(ComputeManagementClient)

    instance_ids = []
    machine_to_id = list_instance_ids(name)
    for vm_id in vm_ids:
        if vm_id in machine_to_id:
            instance_ids.append(machine_to_id[vm_id])
        else:
            logging.info("unable to find vm_id for %s:%s", name, vm_id)

    if instance_ids:
        compute_client.virtual_machine_scale_sets.reimage_all(
            resource_group, str(name), instance_ids=instance_ids
        )
    return None


def delete_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
    check_can_update(name)

    resource_group = get_base_resource_group()
    logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
    compute_client = mgmt_client_factory(ComputeManagementClient)

    instance_ids = []
    machine_to_id = list_instance_ids(name)
    for vm_id in vm_ids:
        if vm_id in machine_to_id:
            instance_ids.append(machine_to_id[vm_id])
        else:
            logging.info("unable to find vm_id for %s:%s", name, vm_id)

    if instance_ids:
        compute_client.virtual_machine_scale_sets.delete_instances(
            resource_group, str(name), instance_ids=instance_ids
        )
    return None


def update_extensions(name: UUID, extensions: List[Any]) -> None:
    check_can_update(name)

    resource_group = get_base_resource_group()
    logging.info("updating VM extensions: %s", name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    compute_client.virtual_machine_scale_sets.update(
        resource_group,
        str(name),
        {"virtual_machine_profile": {"extension_profile": {"extensions": extensions}}},
    )


def create_vmss(
    location: Region,
    name: UUID,
    vm_sku: str,
    vm_count: int,
    image: str,
    network_id: str,
    spot_instances: bool,
    extensions: List[Any],
    password: str,
    ssh_public_key: str,
    tags: Dict[str, str],
) -> Optional[Error]:

    vmss = get_vmss(name)
    if vmss is not None:
        return None

    logging.info(
        "creating VM count"
        "name: %s vm_sku: %s vm_count: %d "
        "image: %s subnet: %s spot_instances: %s",
        name,
        vm_sku,
        vm_count,
        image,
        network_id,
        spot_instances,
    )

    resource_group = get_base_resource_group()

    compute_client = mgmt_client_factory(ComputeManagementClient)

    if image.startswith("/"):
        image_ref = {"id": image}
    else:
        image_val = image.split(":", 4)
        image_ref = {
            "publisher": image_val[0],
            "offer": image_val[1],
            "sku": image_val[2],
            "version": image_val[3],
        }

    sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}

    params: Dict[str, Any] = {
        "location": location,
        "do_not_run_extensions_on_overprovisioned_vms": True,
        "upgrade_policy": {"mode": "Manual"},
        "sku": sku,
        "identity": {"type": "SystemAssigned"},
        "virtual_machine_profile": {
            "priority": "Regular",
            "storage_profile": {"image_reference": image_ref},
            "os_profile": {
                "computer_name_prefix": "node",
                "admin_username": "onefuzz",
                "admin_password": password,
            },
            "network_profile": {
                "network_interface_configurations": [
                    {
                        "name": "onefuzz-nic",
                        "primary": True,
                        "ip_configurations": [
                            {"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
                        ],
                    }
                ]
            },
            "extension_profile": {"extensions": extensions},
        },
    }

    image_os = get_os(location, image)
    if isinstance(image_os, Error):
        return image_os

    if image_os == OS.linux:
        params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
            "disable_password_authentication": True,
            "ssh": {
                "public_keys": [
                    {
                        "path": "/home/onefuzz/.ssh/authorized_keys",
                        "key_data": ssh_public_key,
                    }
                ]
            },
        }

    if spot_instances:
        # Setting max price to -1 means it won't be evicted because of
        # price.
        #
        # https://docs.microsoft.com/en-us/azure/
        # virtual-machine-scale-sets/use-spot#resource-manager-templates
        params["virtual_machine_profile"].update(
            {
                "eviction_policy": "Delete",
                "priority": "Spot",
                "billing_profile": {"max_price": -1},
            }
        )

    params["tags"] = tags.copy()
    owner = os.environ.get("ONEFUZZ_OWNER")
    if owner:
        params["tags"]["OWNER"] = owner

    try:
        compute_client.virtual_machine_scale_sets.create_or_update(
            resource_group, name, params
        )
    except CloudError as err:
        if "The request failed due to conflict with a concurrent request" in repr(err):
            logging.debug(
                "create VM had conflicts with concurrent request, ignoring %s", err
            )
            return None
        return Error(
            code=ErrorCode.VM_CREATE_FAILED,
            errors=["creating vmss: %s" % err],
        )

    return None


@cached(ttl=60)
def list_available_skus(location: str) -> List[str]:
    compute_client = mgmt_client_factory(ComputeManagementClient)
    skus: List[ResourceSku] = list(
        compute_client.resource_skus.list(filter="location eq '%s'" % location)
    )
    sku_names: List[str] = []
    for sku in skus:
        available = True
        if sku.restrictions is not None:
            for restriction in sku.restrictions:
                if restriction.type == ResourceSkuRestrictionsType.location and (
                    location.upper() in [v.upper() for v in restriction.values]
                ):
                    available = False
                    break

        if available:
            sku_names.append(sku.name)
    return sku_names
src/api-service/__app__/onefuzzlib/dashboard.py (new file, 54 lines)
@@ -0,0 +1,54 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
from enum import Enum
from queue import Empty, Queue
from typing import Dict, Optional, Union
from uuid import UUID

from onefuzztypes.primitives import Event

EVENTS: Queue = Queue()


def resolve(data: Event) -> Union[str, int, Dict[str, str]]:
    if isinstance(data, str):
        return data
    if isinstance(data, UUID):
        return str(data)
    elif isinstance(data, Enum):
        return data.name
    elif isinstance(data, int):
        return data
    elif isinstance(data, dict):
        for x in data:
            data[x] = str(data[x])
        return data
    raise NotImplementedError("no conversion from %s" % type(data))


def get_event() -> Optional[str]:
    events = []

    for _ in range(10):
        try:
            (event, data) = EVENTS.get(block=False)
            events.append({"type": event, "data": data})
            EVENTS.task_done()
        except Empty:
            break

    if events:
        return json.dumps({"target": "dashboard", "arguments": events})
    else:
        return None


def add_event(message_type: str, data: Dict[str, Event]) -> None:
    for key in data:
        data[key] = resolve(data[key])

    EVENTS.put((message_type, data))
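For context, a brief sketch of the intended flow through dashboard.py above: producers push JSON-friendly values with add_event, and a poller drains batches with get_event. The event name and payload below are made-up examples, not part of this commit:

from uuid import uuid4

# producer side: resolve() normalizes values (UUIDs and enums become strings)
add_event("task_state_updated", {"task_id": uuid4(), "state": "scheduled"})

# consumer side: drains up to 10 queued events into one dashboard message,
# or returns None when nothing is queued
message = get_event()
if message:
    print(message)  # {"target": "dashboard", "arguments": [...]}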
328
src/api-service/__app__/onefuzzlib/extension.py
Normal file
328
src/api-service/__app__/onefuzzlib/extension.py
Normal file
@ -0,0 +1,328 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import OS, AgentMode
|
||||
from onefuzztypes.models import AgentConfig, ReproConfig
|
||||
from onefuzztypes.primitives import Extension, Region
|
||||
|
||||
from .azure.containers import get_container_sas_url, get_file_sas_url, save_blob
|
||||
from .azure.creds import get_func_storage, get_instance_name
|
||||
from .azure.monitor import get_monitor_settings
|
||||
from .reports import get_report
|
||||
|
||||
# TODO: figure out how to create VM specific SSH keys for Windows.
|
||||
#
|
||||
# Previously done via task specific scripts:
|
||||
|
||||
# if is_windows and auth is not None:
|
||||
# ssh_key = auth.public_key.strip()
|
||||
# ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
|
||||
# commands += ['Set-Content -Path %s -Value "%s"' % (ssh_path, ssh_key)]
|
||||
# return commands
|
||||
|
||||
|
||||
def generic_extensions(region: Region, os: OS) -> List[Extension]:
    extensions = [monitor_extension(region, os)]
    dependency = dependency_extension(region, os)
    if dependency:
        extensions.append(dependency)

    return extensions
|
||||
|
||||
|
||||
def monitor_extension(region: Region, os: OS) -> Extension:
|
||||
settings = get_monitor_settings()
|
||||
|
||||
if os == OS.windows:
|
||||
return {
|
||||
"name": "OMSExtension",
|
||||
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
|
||||
"type": "MicrosoftMonitoringAgent",
|
||||
"typeHandlerVersion": "1.0",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"workspaceId": settings["id"]},
|
||||
"protectedSettings": {"workspaceKey": settings["key"]},
|
||||
}
|
||||
elif os == OS.linux:
|
||||
return {
|
||||
"name": "OMSExtension",
|
||||
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
|
||||
"type": "OmsAgentForLinux",
|
||||
"typeHandlerVersion": "1.12",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"workspaceId": settings["id"]},
|
||||
"protectedSettings": {"workspaceKey": settings["key"]},
|
||||
}
|
||||
raise NotImplementedError("unsupported os: %s" % os)
|
||||
|
||||
|
||||
def dependency_extension(region: Region, os: OS) -> Optional[Extension]:
|
||||
if os == OS.windows:
|
||||
extension = {
|
||||
"name": "DependencyAgentWindows",
|
||||
"publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
|
||||
"type": "DependencyAgentWindows",
|
||||
"typeHandlerVersion": "9.5",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
}
|
||||
return extension
|
||||
else:
|
||||
# TODO: dependency agent for linux is not reliable
|
||||
# extension = {
|
||||
# "name": "DependencyAgentLinux",
|
||||
# "publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
|
||||
# "type": "DependencyAgentLinux",
|
||||
# "typeHandlerVersion": "9.5",
|
||||
# "location": vm.region,
|
||||
# "autoUpgradeMinorVersion": True,
|
||||
# }
|
||||
return None
|
||||
|
||||
|
||||
def build_pool_config(pool_name: str) -> str:
|
||||
agent_config = AgentConfig(
|
||||
pool_name=pool_name,
|
||||
onefuzz_url="https://%s.azurewebsites.net" % get_instance_name(),
|
||||
instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
||||
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
||||
)
|
||||
|
||||
save_blob(
|
||||
"vm-scripts",
|
||||
"%s/config.json" % pool_name,
|
||||
agent_config.json(),
|
||||
account_id=get_func_storage(),
|
||||
)
|
||||
|
||||
return get_file_sas_url(
|
||||
"vm-scripts",
|
||||
"%s/config.json" % pool_name,
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
)
|
||||
|
||||
|
||||
def update_managed_scripts(mode: AgentMode) -> None:
|
||||
commands = [
|
||||
"azcopy sync '%s' instance-specific-setup"
|
||||
% (
|
||||
get_container_sas_url(
|
||||
"instance-specific-setup",
|
||||
read=True,
|
||||
list=True,
|
||||
account_id=get_func_storage(),
|
||||
)
|
||||
),
|
||||
"azcopy sync '%s' tools"
|
||||
% (
|
||||
get_container_sas_url(
|
||||
"tools", read=True, list=True, account_id=get_func_storage()
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
save_blob(
|
||||
"vm-scripts",
|
||||
"managed.ps1",
|
||||
"\r\n".join(commands) + "\r\n",
|
||||
account_id=get_func_storage(),
|
||||
)
|
||||
save_blob(
|
||||
"vm-scripts",
|
||||
"managed.sh",
|
||||
"\n".join(commands) + "\n",
|
||||
account_id=get_func_storage(),
|
||||
)
|
||||
|
||||
|
||||
def agent_config(
|
||||
region: Region, os: OS, mode: AgentMode, *, urls: Optional[List[str]] = None
|
||||
) -> Extension:
|
||||
update_managed_scripts(mode)
|
||||
|
||||
if urls is None:
|
||||
urls = []
|
||||
|
||||
if os == OS.windows:
|
||||
urls += [
|
||||
get_file_sas_url(
|
||||
"vm-scripts",
|
||||
"managed.ps1",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"tools",
|
||||
"win64/azcopy.exe",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"tools",
|
||||
"win64/setup.ps1",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"tools",
|
||||
"win64/onefuzz.ps1",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
to_execute_cmd = (
|
||||
"powershell -ExecutionPolicy Unrestricted -File win64/setup.ps1 "
|
||||
"-mode %s" % (mode.name)
|
||||
)
|
||||
extension = {
|
||||
"name": "CustomScriptExtension",
|
||||
"type": "CustomScriptExtension",
|
||||
"publisher": "Microsoft.Compute",
|
||||
"location": region,
|
||||
"type_handler_version": "1.9",
|
||||
"auto_upgrade_minor_version": True,
|
||||
"settings": {"commandToExecute": to_execute_cmd, "fileUris": urls},
|
||||
"protectedSettings": {},
|
||||
}
|
||||
return extension
|
||||
elif os == OS.linux:
|
||||
urls += [
|
||||
get_file_sas_url(
|
||||
"vm-scripts",
|
||||
"managed.sh",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"tools",
|
||||
"linux/azcopy",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"tools",
|
||||
"linux/setup.sh",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
to_execute_cmd = "sh setup.sh %s" % (mode.name)
|
||||
|
||||
extension = {
|
||||
"name": "CustomScript",
|
||||
"publisher": "Microsoft.Azure.Extensions",
|
||||
"type": "CustomScript",
|
||||
"typeHandlerVersion": "2.1",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"commandToExecute": to_execute_cmd, "fileUris": urls},
|
||||
"protectedSettings": {},
|
||||
}
|
||||
return extension
|
||||
|
||||
raise NotImplementedError("unsupported OS: %s" % os)
|
||||
|
||||
|
||||
def fuzz_extensions(region: Region, os: OS, pool_name: str) -> List[Extension]:
|
||||
urls = [build_pool_config(pool_name)]
|
||||
fuzz_extension = agent_config(region, os, AgentMode.fuzz, urls=urls)
|
||||
extensions = generic_extensions(region, os)
|
||||
extensions += [fuzz_extension]
|
||||
return extensions
|
||||
|
||||
|
||||
def repro_extensions(
|
||||
region: Region,
|
||||
repro_os: OS,
|
||||
repro_id: UUID,
|
||||
repro_config: ReproConfig,
|
||||
setup_container: Optional[str],
|
||||
) -> List[Extension]:
|
||||
# TODO - what about contents of repro.ps1 / repro.sh?
|
||||
report = get_report(repro_config.container, repro_config.path)
|
||||
if report is None:
|
||||
raise Exception("invalid report: %s" % repro_config)
|
||||
|
||||
commands = []
|
||||
if setup_container:
|
||||
commands += [
|
||||
"azcopy sync '%s' ./setup"
|
||||
% (get_container_sas_url(setup_container, read=True, list=True)),
|
||||
]
|
||||
|
||||
urls = [
|
||||
get_file_sas_url(repro_config.container, repro_config.path, read=True),
|
||||
get_file_sas_url(
|
||||
report.input_blob.container, report.input_blob.name, read=True
|
||||
),
|
||||
]
|
||||
|
||||
repro_files = []
|
||||
if repro_os == OS.windows:
|
||||
repro_files = ["%s/repro.ps1" % repro_id]
|
||||
task_script = "\r\n".join(commands)
|
||||
script_name = "task-setup.ps1"
|
||||
else:
|
||||
repro_files = ["%s/repro.sh" % repro_id, "%s/repro-stdout.sh" % repro_id]
|
||||
commands += ["chmod -R +x setup"]
|
||||
task_script = "\n".join(commands)
|
||||
script_name = "task-setup.sh"
|
||||
|
||||
save_blob(
|
||||
"task-configs",
|
||||
"%s/%s" % (repro_id, script_name),
|
||||
task_script,
|
||||
account_id=get_func_storage(),
|
||||
)
|
||||
|
||||
for repro_file in repro_files:
|
||||
urls += [
|
||||
get_file_sas_url(
|
||||
"repro-scripts",
|
||||
repro_file,
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"task-configs",
|
||||
"%s/%s" % (repro_id, script_name),
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
|
||||
base_extension = agent_config(region, repro_os, AgentMode.repro, urls=urls)
|
||||
extensions = generic_extensions(region, repro_os)
|
||||
extensions += [base_extension]
|
||||
return extensions
|
||||
|
||||
|
||||
def proxy_manager_extensions(region: Region) -> List[Extension]:
|
||||
urls = [
|
||||
get_file_sas_url(
|
||||
"proxy-configs",
|
||||
"%s/config.json" % region,
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
"tools",
|
||||
"linux/onefuzz-proxy-manager",
|
||||
account_id=get_func_storage(),
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
|
||||
base_extension = agent_config(region, OS.linux, AgentMode.proxy, urls=urls)
|
||||
extensions = generic_extensions(region, OS.linux)
|
||||
extensions += [base_extension]
|
||||
return extensions
|
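To tie the extension builders above together, here is a rough sketch of how the generated extension list is meant to be consumed. In this commit the scaleset code (pools.py, later in this diff) passes the result of fuzz_extensions into create_vmss; the values below are placeholders and the snippet stops short of the actual VMSS call.

# Illustrative sketch (not part of this commit): building the extension list for
# a new fuzzing scaleset before handing it to the VMSS creation helper.
from onefuzztypes.enums import OS
from onefuzztypes.primitives import Region

region = Region("eastus")          # placeholder region
pool_name = "example-linux-pool"   # placeholder pool

extensions = fuzz_extensions(region, OS.linux, pool_name)
# extensions now holds the monitoring agent plus the CustomScript extension that
# downloads managed.sh / azcopy / setup.sh and runs "sh setup.sh fuzz".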
45
src/api-service/__app__/onefuzzlib/heartbeat.py
Normal file
@@ -0,0 +1,45 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import List, Tuple
from uuid import UUID

from onefuzztypes.models import Heartbeat as BASE
from onefuzztypes.models import HeartbeatEntry, HeartbeatSummary

from .orm import ORMMixin


class Heartbeat(BASE, ORMMixin):
    @classmethod
    def add(cls, entry: HeartbeatEntry) -> None:
        for value in entry.data:
            heartbeat_id = "-".join([str(entry.machine_id), value["type"].name])
            heartbeat = cls(
                task_id=entry.task_id,
                heartbeat_id=heartbeat_id,
                machine_id=entry.machine_id,
                heartbeat_type=value["type"],
            )
            heartbeat.save()

    @classmethod
    def get_heartbeats(cls, task_id: UUID) -> List[HeartbeatSummary]:
        entries = cls.search(query={"task_id": [task_id]})

        result = []
        for entry in entries:
            result.append(
                HeartbeatSummary(
                    timestamp=entry.Timestamp,
                    machine_id=entry.machine_id,
                    type=entry.heartbeat_type,
                )
            )
        return result

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("task_id", "heartbeat_id")
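A small sketch of the read side of the model above: fetching the per-machine heartbeat summaries recorded for a task. The task id is a placeholder.

# Illustrative sketch (not part of this commit): summarizing heartbeats for a task.
from uuid import UUID

task_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder
for summary in Heartbeat.get_heartbeats(task_id):
    print(summary.machine_id, summary.type, summary.timestamp)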
83
src/api-service/__app__/onefuzzlib/jobs.py
Normal file
@@ -0,0 +1,83 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

from onefuzztypes.enums import JobState, TaskState
from onefuzztypes.models import Job as BASE_JOB

from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .tasks.main import Task


class Job(BASE_JOB, ORMMixin):
    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        return ("job_id", None)

    @classmethod
    def search_states(cls, *, states: Optional[List[JobState]] = None) -> List["Job"]:
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    @classmethod
    def search_expired(cls) -> List["Job"]:
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()

        return cls.search(
            query={"state": JobState.available()}, raw_unchecked_filter=time_filter
        )

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        return {"task_info": ...}

    def event_include(self) -> Optional[MappingIntStrAny]:
        return {
            "job_id": ...,
            "state": ...,
            "error": ...,
        }

    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "machine_id": ...,
            "state": ...,
            "scaleset_id": ...,
        }

    def init(self) -> None:
        logging.info("init job: %s", self.job_id)
        self.state = JobState.enabled
        self.save()

    def stopping(self) -> None:
        self.state = JobState.stopping
        logging.info("stopping job: %s", self.job_id)
        not_stopped = [
            task
            for task in Task.search(query={"job_id": [self.job_id]})
            if task.state != TaskState.stopped
        ]

        if not_stopped:
            for task in not_stopped:
                task.state = TaskState.stopping
                task.save()
        else:
            self.state = JobState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    def on_start(self) -> None:
        # try to keep this effectively idempotent
        if self.end_time is None:
            self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
            self.save()
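A hedged sketch of how the Job state helpers above are typically driven from a timer: expired jobs are moved to stopping, which in turn marks any unfinished tasks as stopping. The loop below mirrors what a scheduling timer might do; it is not part of this file.

# Illustrative sketch (not part of this commit): expiring jobs from a timer.
for job in Job.search_expired():
    # queue_stop enqueues the state transition; the queued update later calls
    # job.stopping(), which cascades TaskState.stopping to unfinished tasks.
    job.queue_stop()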
206
src/api-service/__app__/onefuzzlib/notifications/ado.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
from azure.devops.connection import Connection
|
||||
from azure.devops.credentials import BasicAuthentication
|
||||
from azure.devops.exceptions import AzureDevOpsServiceError
|
||||
from azure.devops.v6_0.work_item_tracking.models import (
|
||||
CommentCreate,
|
||||
JsonPatchOperation,
|
||||
Wiql,
|
||||
WorkItem,
|
||||
)
|
||||
from azure.devops.v6_0.work_item_tracking.work_item_tracking_client import (
|
||||
WorkItemTrackingClient,
|
||||
)
|
||||
from memoization import cached
|
||||
from onefuzztypes.models import ADOTemplate, Report
|
||||
|
||||
from .common import Render
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_ado_client(base_url: str, token: str) -> WorkItemTrackingClient:
|
||||
connection = Connection(base_url=base_url, creds=BasicAuthentication("PAT", token))
|
||||
client = connection.clients_v6_0.get_work_item_tracking_client()
|
||||
return client
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_valid_fields(
|
||||
client: WorkItemTrackingClient, project: Optional[str] = None
|
||||
) -> List[str]:
|
||||
valid_fields = [
|
||||
x.reference_name.lower()
|
||||
for x in client.get_fields(project=project, expand="ExtensionFields")
|
||||
]
|
||||
return valid_fields
|
||||
|
||||
|
||||
class ADO:
|
||||
def __init__(
|
||||
self, container: str, filename: str, config: ADOTemplate, report: Report
|
||||
):
|
||||
self.config = config
|
||||
self.renderer = Render(container, filename, report)
|
||||
self.client = get_ado_client(self.config.base_url, self.config.auth_token)
|
||||
self.project = self.render(self.config.project)
|
||||
|
||||
def render(self, template: str) -> str:
|
||||
return self.renderer.render(template)
|
||||
|
||||
def existing_work_items(self) -> Iterator[WorkItem]:
|
||||
filters = {}
|
||||
for key in self.config.unique_fields:
|
||||
if key == "System.TeamProject":
|
||||
value = self.render(self.config.project)
|
||||
else:
|
||||
value = self.render(self.config.ado_fields[key])
|
||||
filters[key.lower()] = value
|
||||
|
||||
valid_fields = get_valid_fields(
|
||||
self.client, project=filters.get("system.teamproject")
|
||||
)
|
||||
|
||||
post_query_filter = {}
|
||||
|
||||
# WIQL (Work Item Query Language) is an SQL-like query language that
# doesn't support query params, safe quoting, or any other SQL-injection
# protection mechanisms.
#
# As such, build the WIQL with those fields we can pre-determine are
# "safe" and otherwise use post-query filtering.
|
||||
parts = []
|
||||
for k, v in filters.items():
|
||||
# Only add fields the system has pre-approved to the query
|
||||
if k not in valid_fields:
|
||||
post_query_filter[k] = v
|
||||
continue
|
||||
|
||||
# WIQL supports wrapping values in ' or " and escaping ' by doubling it
|
||||
#
|
||||
# For this System.Title: hi'there
|
||||
# use this query fragment: [System.Title] = 'hi''there'
|
||||
#
|
||||
# For this System.Title: hi"there
|
||||
# use this query fragment: [System.Title] = 'hi"there'
|
||||
#
|
||||
# For this System.Title: hi'"there
|
||||
# use this query fragment: [System.Title] = 'hi''"there'
|
||||
SINGLE = "'"
|
||||
parts.append("[%s] = '%s'" % (k, v.replace(SINGLE, SINGLE + SINGLE)))
|
||||
|
||||
query = "select [System.Id] from WorkItems"
|
||||
if parts:
|
||||
query += " where " + " AND ".join(parts)
|
||||
|
||||
wiql = Wiql(query=query)
|
||||
for entry in self.client.query_by_wiql(wiql).work_items:
|
||||
item = self.client.get_work_item(entry.id, expand="Fields")
|
||||
lowered_fields = {x.lower(): str(y) for (x, y) in item.fields.items()}
|
||||
if post_query_filter and not all(
|
||||
[
|
||||
k.lower() in lowered_fields and lowered_fields[k.lower()] == v
|
||||
for (k, v) in post_query_filter.items()
|
||||
]
|
||||
):
|
||||
continue
|
||||
yield item
|
||||
|
||||
def update_existing(self, item: WorkItem) -> None:
|
||||
if self.config.on_duplicate.comment:
|
||||
comment = self.render(self.config.on_duplicate.comment)
|
||||
self.client.add_comment(
|
||||
CommentCreate(comment),
|
||||
self.project,
|
||||
item.id,
|
||||
)
|
||||
|
||||
document = []
|
||||
for field in self.config.on_duplicate.increment:
|
||||
value = int(item.fields[field]) if field in item.fields else 0
|
||||
value += 1
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Replace", path="/fields/%s" % field, value=str(value)
|
||||
)
|
||||
)
|
||||
|
||||
for field in self.config.on_duplicate.ado_fields:
|
||||
field_value = self.render(self.config.on_duplicate.ado_fields[field])
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Replace", path="/fields/%s" % field, value=field_value
|
||||
)
|
||||
)
|
||||
|
||||
if item.fields["System.State"] in self.config.on_duplicate.set_state:
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Replace",
|
||||
path="/fields/System.State",
|
||||
value=self.config.on_duplicate.set_state[
|
||||
item.fields["System.State"]
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if document:
|
||||
self.client.update_work_item(document, item.id, project=self.project)
|
||||
|
||||
def create_new(self) -> None:
|
||||
task_type = self.render(self.config.type)
|
||||
|
||||
document = []
|
||||
if "System.Tags" not in self.config.ado_fields:
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Add", path="/fields/System.Tags", value="Onefuzz"
|
||||
)
|
||||
)
|
||||
|
||||
for field in self.config.ado_fields:
|
||||
value = self.render(self.config.ado_fields[field])
|
||||
if field == "System.Tags":
|
||||
value += ";Onefuzz"
|
||||
document.append(
|
||||
JsonPatchOperation(op="Add", path="/fields/%s" % field, value=value)
|
||||
)
|
||||
|
||||
entry = self.client.create_work_item(
|
||||
document=document, project=self.project, type=task_type
|
||||
)
|
||||
|
||||
if self.config.comment:
|
||||
comment = self.render(self.config.comment)
|
||||
self.client.add_comment(
|
||||
CommentCreate(comment),
|
||||
self.project,
|
||||
entry.id,
|
||||
)
|
||||
|
||||
def process(self) -> None:
|
||||
seen = False
|
||||
for work_item in self.existing_work_items():
|
||||
self.update_existing(work_item)
|
||||
seen = True
|
||||
|
||||
if not seen:
|
||||
self.create_new()
|
||||
|
||||
|
||||
def notify_ado(
|
||||
config: ADOTemplate, container: str, filename: str, report: Report
|
||||
) -> None:
|
||||
try:
|
||||
ado = ADO(container, filename, config, report)
|
||||
ado.process()
|
||||
except AzureDevOpsServiceError as err:
|
||||
logging.error("ADO report failed: %s", err)
|
||||
except ValueError as err:
|
||||
logging.error("ADO report value error: %s", err)
|
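The WIQL-building comments above describe the quoting rule in prose; the fragment below restates it as a tiny standalone helper for clarity. It is an illustration of the same escaping, not an addition to the module.

# Illustrative sketch (not part of this commit): the single-quote escaping rule
# used when building WIQL clauses for pre-approved fields.
def wiql_clause(field: str, value: str) -> str:
    SINGLE = "'"
    return "[%s] = '%s'" % (field, value.replace(SINGLE, SINGLE + SINGLE))

# wiql_clause("System.Title", "hi'there")  -> [System.Title] = 'hi''there'
# wiql_clause("System.Title", 'hi"there')  -> [System.Title] = 'hi"there'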
61
src/api-service/__app__/onefuzzlib/notifications/common.py
Normal file
@@ -0,0 +1,61 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Optional

from jinja2.sandbox import SandboxedEnvironment
from onefuzztypes.models import Report

from ..azure.containers import auth_download_url
from ..jobs import Job
from ..tasks.config import get_setup_container
from ..tasks.main import Task


class Render:
    def __init__(self, container: str, filename: str, report: Report):
        self.report = report
        self.container = container
        self.filename = filename
        task = Task.get(report.job_id, report.task_id)
        if not task:
            raise ValueError
        job = Job.get(report.job_id)
        if not job:
            raise ValueError

        self.task_config = task.config
        self.job_config = job.config
        self.env = SandboxedEnvironment()

        self.target_url: Optional[str] = None
        setup_container = get_setup_container(task.config)
        if setup_container:
            self.target_url = auth_download_url(
                setup_container, self.report.executable.replace("setup/", "", 1)
            )

        self.report_url = auth_download_url(container, filename)
        self.input_url: Optional[str] = None
        if self.report.input_blob:
            self.input_url = auth_download_url(
                self.report.input_blob.container, self.report.input_blob.name
            )

    def render(self, template: str) -> str:
        return self.env.from_string(template).render(
            {
                "report": self.report,
                "task": self.task_config,
                "job": self.job_config,
                "report_url": self.report_url,
                "input_url": self.input_url,
                "target_url": self.target_url,
                "report_container": self.container,
                "report_filename": self.filename,
                "repro_cmd": "onefuzz repro create_and_connect %s %s"
                % (self.container, self.filename),
            }
        )
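A short sketch of how the Render helper above is used by the notification integrations (ado.py earlier and teams.py below): a sandboxed Jinja2 template is expanded against the report, task, and job that produced a crash. The container, filename, and template string are examples only, and report is assumed to be a Report already loaded via get_report().

# Illustrative sketch (not part of this commit): rendering a notification title.
renderer = Render("my-reports-container", "crash-1234.json", report)
title = renderer.render(
    "{{ report.crash_type }} in {{ report.executable }} (job {{ job.name }})"
)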
122
src/api-service/__app__/onefuzzlib/notifications/main.py
Normal file
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from memoization import cached
|
||||
from onefuzztypes import models
|
||||
from onefuzztypes.enums import ErrorCode, TaskState
|
||||
from onefuzztypes.models import ADOTemplate, Error, NotificationTemplate, TeamsTemplate
|
||||
from onefuzztypes.primitives import Container, Event
|
||||
|
||||
from ..azure.containers import get_container_metadata, get_file_sas_url
|
||||
from ..azure.creds import get_fuzz_storage
|
||||
from ..azure.queue import send_message
|
||||
from ..dashboard import add_event
|
||||
from ..orm import ORMMixin
|
||||
from ..reports import get_report
|
||||
from ..tasks.config import get_input_container_queues
|
||||
from ..tasks.main import Task
|
||||
from .ado import notify_ado
|
||||
from .teams import notify_teams
|
||||
|
||||
|
||||
class Notification(models.Notification, ORMMixin):
|
||||
@classmethod
|
||||
def get_by_id(cls, notification_id: UUID) -> Union[Error, "Notification"]:
|
||||
notifications = cls.search(query={"notification_id": [notification_id]})
|
||||
if not notifications:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["unable to find Notification"]
|
||||
)
|
||||
|
||||
if len(notifications) != 1:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["error identifying Notification"],
|
||||
)
|
||||
notification = notifications[0]
|
||||
return notification
|
||||
|
||||
@classmethod
|
||||
def get_existing(
|
||||
cls, container: Container, config: NotificationTemplate
|
||||
) -> Optional["Notification"]:
|
||||
notifications = Notification.search(query={"container": [container]})
|
||||
for notification in notifications:
|
||||
if notification.config == config:
|
||||
return notification
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("notification_id", "container")
|
||||
|
||||
|
||||
@cached(ttl=10)
|
||||
def get_notifications(container: Container) -> List[Notification]:
|
||||
return Notification.search(query={"container": [container]})
|
||||
|
||||
|
||||
@cached(ttl=10)
|
||||
def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
|
||||
results = []
|
||||
for task in Task.search_states(states=TaskState.available()):
|
||||
containers = get_input_container_queues(task.config)
|
||||
if containers:
|
||||
results.append((task, containers))
|
||||
return results
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def container_metadata(container: Container) -> Optional[Dict[str, str]]:
|
||||
return get_container_metadata(container)
|
||||
|
||||
|
||||
def new_files(container: Container, filename: str) -> None:
|
||||
results: Dict[str, Event] = {"container": container, "file": filename}
|
||||
|
||||
metadata = container_metadata(container)
|
||||
if metadata:
|
||||
results["metadata"] = metadata
|
||||
|
||||
notifications = get_notifications(container)
|
||||
if notifications:
|
||||
report = get_report(container, filename)
|
||||
if report:
|
||||
results["executable"] = report.executable
|
||||
results["crash_type"] = report.crash_type
|
||||
results["crash_site"] = report.crash_site
|
||||
results["job_id"] = report.job_id
|
||||
results["task_id"] = report.task_id
|
||||
|
||||
logging.info("notifications for %s %s %s", container, filename, notifications)
|
||||
done = []
|
||||
for notification in notifications:
|
||||
# ignore duplicate configurations
|
||||
if notification.config in done:
|
||||
continue
|
||||
done.append(notification.config)
|
||||
|
||||
if isinstance(notification.config, TeamsTemplate):
|
||||
notify_teams(notification.config, container, filename, report)
|
||||
|
||||
if not report:
|
||||
continue
|
||||
|
||||
if isinstance(notification.config, ADOTemplate):
|
||||
notify_ado(notification.config, container, filename, report)
|
||||
|
||||
for (task, containers) in get_queue_tasks():
|
||||
if container in containers:
|
||||
logging.info("queuing input %s %s %s", container, filename, task.task_id)
|
||||
url = get_file_sas_url(container, filename, read=True, delete=True)
|
||||
send_message(
|
||||
task.task_id, bytes(url, "utf-8"), account_id=get_fuzz_storage()
|
||||
)
|
||||
|
||||
add_event("new_file", results)
|
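A sketch of the entry point above in context: new_files is invoked (for example from a queue trigger) whenever a blob lands in a monitored container, and fans out to Teams/ADO notifications, input queueing for tasks, and a dashboard event. The container and file names are placeholders.

# Illustrative sketch (not part of this commit): reacting to a newly uploaded blob.
from onefuzztypes.primitives import Container

new_files(Container("my-crashes-container"), "crash-1234.json")
# - matching Notification configs fire the Teams/ADO integrations
# - tasks whose input queues include this container get the file queued
# - a "new_file" event is pushed to the dashboard queue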
127
src/api-service/__app__/onefuzzlib/notifications/teams.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from onefuzztypes.models import Report, TeamsTemplate
|
||||
|
||||
from ..azure.containers import auth_download_url
|
||||
from ..tasks.config import get_setup_container
|
||||
from ..tasks.main import Task
|
||||
|
||||
|
||||
def markdown_escape(data: str) -> str:
|
||||
values = "\\*_{}[]()#+-.!"
|
||||
for value in values:
|
||||
data = data.replace(value, "\\" + value)
|
||||
data = data.replace("`", "``")
|
||||
return data
|
||||
|
||||
|
||||
def code_block(data: str) -> str:
|
||||
data = data.replace("`", "``")
|
||||
return "\n```%s\n```\n" % data
|
||||
|
||||
|
||||
def send_teams_webhook(
|
||||
config: TeamsTemplate,
|
||||
title: str,
|
||||
facts: List[Dict[str, str]],
|
||||
text: Optional[str],
|
||||
) -> None:
|
||||
title = markdown_escape(title)
|
||||
|
||||
message: Dict[str, Any] = {
|
||||
"@type": "MessageCard",
|
||||
"@context": "https://schema.org/extensions",
|
||||
"summary": title,
|
||||
"sections": [{"activityTitle": title, "facts": facts}],
|
||||
}
|
||||
|
||||
if text:
|
||||
message["sections"].append({"text": text})
|
||||
|
||||
response = requests.post(config.url, json=message)
|
||||
if not response.ok:
|
||||
logging.error("webhook failed %s %s", response.status_code, response.content)
|
||||
|
||||
|
||||
def notify_teams(
|
||||
config: TeamsTemplate, container: str, filename: str, report: Optional[Report]
|
||||
) -> None:
|
||||
text = None
|
||||
facts: List[Dict[str, str]] = []
|
||||
|
||||
if report:
|
||||
task = Task.get(report.job_id, report.task_id)
|
||||
if not task:
|
||||
logging.error(
|
||||
"report with invalid task %s:%s", report.job_id, report.task_id
|
||||
)
|
||||
return
|
||||
|
||||
title = "new crash in %s: %s @ %s" % (
|
||||
report.executable,
|
||||
report.crash_type,
|
||||
report.crash_site,
|
||||
)
|
||||
|
||||
links = [
|
||||
"[report](%s)" % auth_download_url(container, filename),
|
||||
]
|
||||
|
||||
setup_container = get_setup_container(task.config)
|
||||
if setup_container:
|
||||
links.append(
|
||||
"[executable](%s)"
|
||||
% auth_download_url(
|
||||
setup_container,
|
||||
report.executable.replace("setup/", "", 1),
|
||||
),
|
||||
)
|
||||
|
||||
if report.input_blob:
|
||||
links.append(
|
||||
"[input](%s)"
|
||||
% auth_download_url(
|
||||
report.input_blob.container, report.input_blob.name
|
||||
),
|
||||
)
|
||||
|
||||
facts += [
|
||||
{"name": "Files", "value": " | ".join(links)},
|
||||
{
|
||||
"name": "Task",
|
||||
"value": markdown_escape(
|
||||
"job_id: %s task_id: %s" % (report.job_id, report.task_id)
|
||||
),
|
||||
},
|
||||
{
|
||||
"name": "Repro",
|
||||
"value": code_block(
|
||||
"onefuzz repro create_and_connect %s %s" % (container, filename)
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
text = "## Call Stack\n" + "\n".join(code_block(x) for x in report.call_stack)
|
||||
|
||||
else:
|
||||
title = "new file found"
|
||||
facts += [
|
||||
{
|
||||
"name": "file",
|
||||
"value": "[%s/%s](%s)"
|
||||
% (
|
||||
markdown_escape(container),
|
||||
markdown_escape(filename),
|
||||
auth_download_url(container, filename),
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
send_teams_webhook(config, title, facts, text)
|
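A minimal sketch of invoking the Teams integration above directly, assuming TeamsTemplate only requires the webhook url. The URL is a placeholder; when report is None the "new file found" card is sent instead of the crash card.

# Illustrative sketch (not part of this commit): sending a "new file" card.
from onefuzztypes.models import TeamsTemplate

config = TeamsTemplate(url="https://example.webhook.office.com/...")  # placeholder
notify_teams(config, "my-crashes-container", "crash-1234.json", report=None)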
435
src/api-service/__app__/onefuzzlib/orm.py
Normal file
@@ -0,0 +1,435 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from azure.common import AzureConflictHttpError, AzureMissingResourceHttpError
|
||||
from onefuzztypes.enums import (
|
||||
ErrorCode,
|
||||
JobState,
|
||||
NodeState,
|
||||
PoolState,
|
||||
ScalesetState,
|
||||
TaskState,
|
||||
TelemetryEvent,
|
||||
UpdateType,
|
||||
VmState,
|
||||
)
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Container, PoolName, Region
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .azure.table import get_client
|
||||
from .dashboard import add_event
|
||||
from .telemetry import track_event_filtered
|
||||
from .updates import queue_update
|
||||
|
||||
A = TypeVar("A", bound="ORMMixin")
|
||||
|
||||
QUERY_VALUE_TYPES = Union[
|
||||
List[int],
|
||||
List[str],
|
||||
List[UUID],
|
||||
List[Region],
|
||||
List[Container],
|
||||
List[PoolName],
|
||||
List[VmState],
|
||||
List[ScalesetState],
|
||||
List[JobState],
|
||||
List[TaskState],
|
||||
List[PoolState],
|
||||
List[NodeState],
|
||||
]
|
||||
QueryFilter = Dict[str, QUERY_VALUE_TYPES]
|
||||
|
||||
SAFE_STRINGS = (UUID, Container, Region, PoolName)
|
||||
KEY = Union[int, str, UUID, Enum]
|
||||
|
||||
QUEUE_DELAY_STOPPING_SECONDS = 30
|
||||
QUEUE_DELAY_CREATE_SECONDS = 5
|
||||
HOURS = 60 * 60
|
||||
|
||||
|
||||
def resolve(key: KEY) -> str:
|
||||
if isinstance(key, str):
|
||||
return key
|
||||
elif isinstance(key, UUID):
|
||||
return str(key)
|
||||
elif isinstance(key, Enum):
|
||||
return key.name
|
||||
elif isinstance(key, int):
|
||||
return str(key)
|
||||
raise NotImplementedError("unsupported type %s - %s" % (type(key), repr(key)))
|
||||
|
||||
|
||||
def build_filters(
|
||||
cls: Type[A], query_args: Optional[QueryFilter]
|
||||
) -> Tuple[Optional[str], QueryFilter]:
|
||||
if not query_args:
|
||||
return (None, {})
|
||||
|
||||
partition_key_field, row_key_field = cls.key_fields()
|
||||
|
||||
search_filter_parts = []
|
||||
post_filters: QueryFilter = {}
|
||||
|
||||
for field, values in query_args.items():
|
||||
if field not in cls.__fields__:
|
||||
raise ValueError("unexpected field %s: %s" % (repr(field), cls))
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
if field == partition_key_field:
|
||||
field_name = "PartitionKey"
|
||||
elif field == row_key_field:
|
||||
field_name = "RowKey"
|
||||
else:
|
||||
field_name = field
|
||||
|
||||
parts: Optional[List[str]] = None
|
||||
if isinstance(values[0], int):
|
||||
parts = []
|
||||
for x in values:
|
||||
if not isinstance(x, int):
|
||||
raise TypeError("unexpected type")
|
||||
parts.append("%s eq %d" % (field_name, x))
|
||||
elif isinstance(values[0], Enum):
|
||||
parts = []
|
||||
for x in values:
|
||||
if not isinstance(x, Enum):
|
||||
raise TypeError("unexpected type")
|
||||
parts.append("%s eq '%s'" % (field_name, x.name))
|
||||
elif all(isinstance(x, SAFE_STRINGS) for x in values):
|
||||
parts = ["%s eq '%s'" % (field_name, x) for x in values]
|
||||
else:
|
||||
post_filters[field_name] = values
|
||||
|
||||
if parts:
|
||||
if len(parts) == 1:
|
||||
search_filter_parts.append(parts[0])
|
||||
else:
|
||||
search_filter_parts.append("(" + " or ".join(parts) + ")")
|
||||
|
||||
if search_filter_parts:
|
||||
return (" and ".join(search_filter_parts), post_filters)
|
||||
|
||||
return (None, post_filters)
|
||||
|
||||
|
||||
def post_filter(value: Any, filters: Optional[QueryFilter]) -> bool:
|
||||
if not filters:
|
||||
return True
|
||||
|
||||
for field in filters:
|
||||
if field not in value:
|
||||
return False
|
||||
if value[field] not in filters[field]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
MappingIntStrAny = Mapping[Union[int, str], Any]
|
||||
# A = TypeVar("A", bound="Model")
|
||||
|
||||
|
||||
class ModelMixin(BaseModel):
|
||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return None
|
||||
|
||||
def raw(
|
||||
self,
|
||||
*,
|
||||
by_alias: bool = False,
|
||||
exclude_none: bool = False,
|
||||
exclude: MappingIntStrAny = None,
|
||||
include: MappingIntStrAny = None,
|
||||
) -> Dict[str, Any]:
|
||||
# cycling through json means all wrapped types get resolved, such as UUID
|
||||
result: Dict[str, Any] = json.loads(
|
||||
self.json(
|
||||
by_alias=by_alias,
|
||||
exclude_none=exclude_none,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class ORMMixin(ModelMixin):
|
||||
Timestamp: Optional[datetime] = Field(alias="Timestamp")
|
||||
etag: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def table_name(cls: Type[A]) -> str:
|
||||
return cls.__name__
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Type[A], PartitionKey: KEY, RowKey: Optional[KEY] = None
|
||||
) -> Optional[A]:
|
||||
client = get_client(table=cls.table_name())
|
||||
partition_key = resolve(PartitionKey)
|
||||
row_key = resolve(RowKey) if RowKey else partition_key
|
||||
|
||||
try:
|
||||
raw = client.get_entity(cls.table_name(), partition_key, row_key)
|
||||
except AzureMissingResourceHttpError:
|
||||
return None
|
||||
return cls.load(raw)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
raise NotImplementedError("keys not defined")
|
||||
|
||||
# FILTERS:
|
||||
# The following hooks control which fields are saved, exported, and reported:
|
||||
# * save_exclude: Specify fields to *exclude* from saving to Storage Tables
|
||||
# * export_exclude: Specify the fields to *exclude* from sending to an external API
|
||||
# * telemetry_include: Specify the fields to *include* for telemetry
|
||||
#
|
||||
# For implementation details see:
|
||||
# https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
|
||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return None
|
||||
|
||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {"etag": ..., "Timestamp": ...}
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {}
|
||||
|
||||
def event_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {}
|
||||
|
||||
def event(self) -> Any:
|
||||
return self.raw(exclude_none=True, include=self.event_include())
|
||||
|
||||
def telemetry(self) -> Any:
|
||||
return self.raw(exclude_none=True, include=self.telemetry_include())
|
||||
|
||||
def _queue_as_needed(self) -> None:
|
||||
# Upon ORM save with state, if the object has a state that needs work,
|
||||
# automatically queue it
|
||||
state = getattr(self, "state", None)
|
||||
if state is None:
|
||||
return
|
||||
needs_work = getattr(state, "needs_work", None)
|
||||
if needs_work is None:
|
||||
return
|
||||
if state not in needs_work():
|
||||
return
|
||||
if state.name in ["stopping", "stop", "shutdown"]:
|
||||
self.queue(visibility_timeout=QUEUE_DELAY_STOPPING_SECONDS)
|
||||
else:
|
||||
self.queue(visibility_timeout=QUEUE_DELAY_CREATE_SECONDS)
|
||||
|
||||
def _event_as_needed(self) -> None:
|
||||
# Upon ORM save, if the object returns event data, we'll send it to the
|
||||
# dashboard event subsystem
|
||||
data = self.event()
|
||||
if not data:
|
||||
return
|
||||
add_event(self.table_name(), data)
|
||||
|
||||
def get_keys(self) -> Tuple[KEY, KEY]:
|
||||
partition_key_field, row_key_field = self.key_fields()
|
||||
|
||||
partition_key = getattr(self, partition_key_field)
|
||||
if row_key_field:
|
||||
row_key = getattr(self, row_key_field)
|
||||
else:
|
||||
row_key = partition_key
|
||||
|
||||
return (partition_key, row_key)
|
||||
|
||||
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
|
||||
# TODO: migrate to an inspect.signature() model
|
||||
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
|
||||
for key in raw:
|
||||
if not isinstance(raw[key], (str, int)):
|
||||
raw[key] = json.dumps(raw[key])
|
||||
|
||||
# for datetime fields that passed through filtering, use the real value,
|
||||
# rather than a serialized form
|
||||
for field in self.__fields__:
|
||||
if field not in raw:
|
||||
continue
|
||||
if self.__fields__[field].type_ == datetime:
|
||||
raw[field] = getattr(self, field)
|
||||
|
||||
partition_key_field, row_key_field = self.key_fields()
|
||||
|
||||
# PartitionKey and RowKey must be 'str'
|
||||
raw["PartitionKey"] = resolve(raw[partition_key_field])
|
||||
raw["RowKey"] = resolve(raw[row_key_field or partition_key_field])
|
||||
|
||||
del raw[partition_key_field]
|
||||
if row_key_field in raw:
|
||||
del raw[row_key_field]
|
||||
|
||||
client = get_client(table=self.table_name())
|
||||
|
||||
# never save the timestamp
|
||||
if "Timestamp" in raw:
|
||||
del raw["Timestamp"]
|
||||
|
||||
if new:
|
||||
try:
|
||||
self.etag = client.insert_entity(self.table_name(), raw)
|
||||
except AzureConflictHttpError:
|
||||
return Error(code=ErrorCode.UNABLE_TO_CREATE, errors=["row exists"])
|
||||
elif self.etag and require_etag:
|
||||
self.etag = client.replace_entity(
|
||||
self.table_name(), raw, if_match=self.etag
|
||||
)
|
||||
else:
|
||||
self.etag = client.insert_or_replace_entity(self.table_name(), raw)
|
||||
|
||||
self._queue_as_needed()
|
||||
if self.table_name() in TelemetryEvent.__members__:
|
||||
telem = self.telemetry()
|
||||
if telem:
|
||||
track_event_filtered(TelemetryEvent[self.table_name()], telem)
|
||||
|
||||
self._event_as_needed()
|
||||
return None
|
||||
|
||||
def delete(self) -> None:
|
||||
# fire off an event so SignalR knows it's being deleted
|
||||
self._event_as_needed()
|
||||
|
||||
partition_key, row_key = self.get_keys()
|
||||
|
||||
client = get_client()
|
||||
try:
|
||||
client.delete_entity(
|
||||
self.table_name(), resolve(partition_key), resolve(row_key)
|
||||
)
|
||||
except AzureMissingResourceHttpError:
|
||||
# It's OK if the component is already deleted
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def load(cls: Type[A], data: Dict[str, Union[str, bytes, bytearray]]) -> A:
|
||||
partition_key_field, row_key_field = cls.key_fields()
|
||||
|
||||
if partition_key_field in data:
|
||||
raise Exception(
|
||||
"duplicate PartitionKey field %s for %s"
|
||||
% (partition_key_field, cls.table_name())
|
||||
)
|
||||
if row_key_field in data:
|
||||
raise Exception(
|
||||
"duplicate RowKey field %s for %s" % (row_key_field, cls.table_name())
|
||||
)
|
||||
|
||||
data[partition_key_field] = data["PartitionKey"]
|
||||
if row_key_field is not None:
|
||||
data[row_key_field] = data["RowKey"]
|
||||
|
||||
del data["PartitionKey"]
|
||||
del data["RowKey"]
|
||||
|
||||
for key in inspect.signature(cls).parameters:
|
||||
if key not in data:
|
||||
continue
|
||||
|
||||
annotation = inspect.signature(cls).parameters[key].annotation
|
||||
|
||||
if inspect.isclass(annotation):
|
||||
if issubclass(annotation, BaseModel) or issubclass(annotation, dict):
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
if getattr(annotation, "__origin__", None) == Union and any(
|
||||
inspect.isclass(x) and issubclass(x, BaseModel)
|
||||
for x in annotation.__args__
|
||||
):
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` annotation is a class
|
||||
# according to `inspect.isclass`.
|
||||
if getattr(annotation, "__origin__", None) == dict:
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
return cls.parse_obj(data)
|
||||
|
||||
@classmethod
|
||||
def search(
|
||||
cls: Type[A],
|
||||
*,
|
||||
query: Optional[QueryFilter] = None,
|
||||
raw_unchecked_filter: Optional[str] = None,
|
||||
num_results: int = None,
|
||||
) -> List[A]:
|
||||
search_filter, post_filters = build_filters(cls, query)
|
||||
|
||||
if raw_unchecked_filter is not None:
|
||||
if search_filter is None:
|
||||
search_filter = raw_unchecked_filter
|
||||
else:
|
||||
search_filter = "(%s) and (%s)" % (search_filter, raw_unchecked_filter)
|
||||
|
||||
client = get_client(table=cls.table_name())
|
||||
entries = []
|
||||
for row in client.query_entities(
|
||||
cls.table_name(), filter=search_filter, num_results=num_results
|
||||
):
|
||||
if not post_filter(row, post_filters):
|
||||
continue
|
||||
|
||||
entry = cls.load(row)
|
||||
entries.append(entry)
|
||||
return entries
|
||||
|
||||
def queue(
|
||||
self,
|
||||
*,
|
||||
method: Optional[Callable] = None,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
if not hasattr(self, "state"):
|
||||
raise NotImplementedError("Queued an ORM mapping without State")
|
||||
|
||||
update_type = UpdateType.__members__.get(type(self).__name__)
|
||||
if update_type is None:
|
||||
raise NotImplementedError("unsupported update type: %s" % self)
|
||||
|
||||
method_name: Optional[str] = None
|
||||
if method is not None:
|
||||
if not hasattr(method, "__name__"):
|
||||
raise Exception("unable to queue method: %s" % method)
|
||||
method_name = method.__name__
|
||||
|
||||
partition_key, row_key = self.get_keys()
|
||||
|
||||
queue_update(
|
||||
update_type,
|
||||
resolve(partition_key),
|
||||
resolve(row_key),
|
||||
method=method_name,
|
||||
visibility_timeout=visibility_timeout,
|
||||
)
|
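To make the ORM contract above concrete, here is a hedged sketch of a minimal table-backed model: subclass ORMMixin, declare key_fields, and use search with a typed QueryFilter. The model and its fields are hypothetical and exist only for illustration.

# Illustrative sketch (not part of this commit): a minimal ORMMixin-backed model.
from typing import Optional, Tuple
from uuid import UUID, uuid4
from pydantic import Field

class ExampleWidget(ORMMixin):  # hypothetical table named "ExampleWidget"
    widget_id: UUID = Field(default_factory=uuid4)
    owner: str
    note: Optional[str] = None

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # widget_id becomes the PartitionKey; RowKey falls back to the same value
        return ("widget_id", None)

widget = ExampleWidget(owner="fuzzing-team")
widget.save(new=True)  # inserts the row; returns an Error on conflict
matches = ExampleWidget.search(query={"owner": ["fuzzing-team"]})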
807
src/api-service/__app__/onefuzzlib/pools.py
Normal file
@@ -0,0 +1,807 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import (
|
||||
OS,
|
||||
Architecture,
|
||||
ErrorCode,
|
||||
NodeState,
|
||||
PoolState,
|
||||
ScalesetState,
|
||||
TaskState,
|
||||
)
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Node as BASE_NODE
|
||||
from onefuzztypes.models import NodeCommand
|
||||
from onefuzztypes.models import NodeTasks as BASE_NODE_TASK
|
||||
from onefuzztypes.models import Pool as BASE_POOL
|
||||
from onefuzztypes.models import Scaleset as BASE_SCALESET
|
||||
from onefuzztypes.models import (
|
||||
ScalesetNodeState,
|
||||
ScalesetSummary,
|
||||
WorkSet,
|
||||
WorkSetSummary,
|
||||
WorkUnitSummary,
|
||||
)
|
||||
from onefuzztypes.primitives import PoolName, Region
|
||||
from pydantic import Field
|
||||
|
||||
from .azure.auth import build_auth
|
||||
from .azure.creds import get_fuzz_storage
|
||||
from .azure.image import get_os
|
||||
from .azure.network import Network
|
||||
from .azure.queue import create_queue, delete_queue, peek_queue, queue_object
|
||||
from .azure.table import get_client
|
||||
from .azure.vmss import (
|
||||
UnableToUpdate,
|
||||
create_vmss,
|
||||
delete_vmss,
|
||||
delete_vmss_nodes,
|
||||
get_instance_id,
|
||||
get_vmss,
|
||||
get_vmss_size,
|
||||
list_instance_ids,
|
||||
reimage_vmss_nodes,
|
||||
resize_vmss,
|
||||
update_extensions,
|
||||
)
|
||||
from .extension import fuzz_extensions
|
||||
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
|
||||
|
||||
# Future work:
|
||||
#
|
||||
# Enabling autoscaling for the scalesets based on the pool work queues.
|
||||
# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
|
||||
|
||||
|
||||
class Node(BASE_NODE, ORMMixin):
|
||||
@classmethod
|
||||
def search_states(
|
||||
cls,
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
states: Optional[List[NodeState]] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
) -> List["Node"]:
|
||||
query: QueryFilter = {}
|
||||
if scaleset_id:
|
||||
query["scaleset_id"] = [scaleset_id]
|
||||
if states:
|
||||
query["state"] = states
|
||||
if pool_name:
|
||||
query["pool_name"] = [pool_name]
|
||||
return cls.search(query=query)
|
||||
|
||||
@classmethod
|
||||
def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
|
||||
nodes = cls.search(query={"machine_id": [machine_id]})
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
if len(nodes) != 1:
|
||||
return None
|
||||
return nodes[0]
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("pool_name", "machine_id")
|
||||
|
||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {"tasks": ...}
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"machine_id": ...,
|
||||
"state": ...,
|
||||
"scaleset_id": ...,
|
||||
}
|
||||
|
||||
def event_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"pool_name": ...,
|
||||
"machine_id": ...,
|
||||
"state": ...,
|
||||
"scaleset_id": ...,
|
||||
}
|
||||
|
||||
def scaleset_node_exists(self) -> bool:
|
||||
if self.scaleset_id is None:
|
||||
return False
|
||||
|
||||
scaleset = Scaleset.get_by_id(self.scaleset_id)
|
||||
if not isinstance(scaleset, Scaleset):
|
||||
return False
|
||||
|
||||
instance_id = get_instance_id(scaleset.scaleset_id, self.machine_id)
|
||||
return isinstance(instance_id, str)
|
||||
|
||||
@classmethod
|
||||
def stop_task(cls, task_id: UUID) -> None:
|
||||
# For now, this just re-images the node. Eventually, this
|
||||
# should send a message to the node to let the agent shut down
|
||||
# gracefully
|
||||
nodes = NodeTasks.get_nodes_by_task_id(task_id)
|
||||
for node in nodes:
|
||||
if node.state not in NodeState.ready_for_reset():
|
||||
logging.info(
|
||||
"stopping task %s on machine_id:%s",
|
||||
task_id,
|
||||
node.machine_id,
|
||||
)
|
||||
node.state = NodeState.done
|
||||
node.save()
|
||||
|
||||
|
||||
class NodeTasks(BASE_NODE_TASK, ORMMixin):
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("machine_id", "task_id")
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"machine_id": ...,
|
||||
"task_id": ...,
|
||||
"state": ...,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_nodes_by_task_id(cls, task_id: UUID) -> List["Node"]:
|
||||
result = []
|
||||
for entry in cls.search(query={"task_id": [task_id]}):
|
||||
node = Node.get_by_machine_id(entry.machine_id)
|
||||
if node:
|
||||
result.append(node)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_by_machine_id(cls, machine_id: UUID) -> List["NodeTasks"]:
|
||||
return cls.search(query={"machine_id": [machine_id]})
|
||||
|
||||
@classmethod
|
||||
def get_by_task_id(cls, task_id: UUID) -> List["NodeTasks"]:
|
||||
return cls.search(query={"task_id": [task_id]})
|
||||
|
||||
|
||||
# this isn't anticipated to be needed by the client, hence it is not
# included in onefuzztypes
|
||||
class NodeMessage(ORMMixin):
|
||||
agent_id: UUID
|
||||
message_id: str = Field(default_factory=datetime.datetime.utcnow().timestamp)
|
||||
message: NodeCommand
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("agent_id", "create_date")
|
||||
|
||||
@classmethod
|
||||
def get_messages(
|
||||
cls, agent_id: UUID, num_results: int = None
|
||||
) -> List["NodeMessage"]:
|
||||
entries: List["NodeMessage"] = cls.search(
|
||||
query={"agent_id": [agent_id]}, num_results=num_results
|
||||
)
|
||||
return entries
|
||||
|
||||
@classmethod
|
||||
def delete_messages(cls, agent_id: UUID, message_ids: List[str]) -> None:
|
||||
client = get_client(table=cls.table_name())
|
||||
batch = client.batch(table_name=cls.table_name())
|
||||
|
||||
for message_id in message_ids:
|
||||
batch.delete_entity(agent_id, message_id)
|
||||
|
||||
client.commit_batch(cls.table_name(), batch)
|
||||
|
||||
|
||||
class Pool(BASE_POOL, ORMMixin):
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
name: PoolName,
|
||||
os: OS,
|
||||
arch: Architecture,
|
||||
managed: bool,
|
||||
client_id: Optional[UUID],
|
||||
) -> "Pool":
|
||||
return cls(
|
||||
name=name,
|
||||
os=os,
|
||||
arch=arch,
|
||||
managed=managed,
|
||||
client_id=client_id,
|
||||
config=None,
|
||||
)
|
||||
|
||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"nodes": ...,
|
||||
"queue": ...,
|
||||
"work_queue": ...,
|
||||
"config": ...,
|
||||
"node_summary": ...,
|
||||
}
|
||||
|
||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"etag": ...,
|
||||
"timestamp": ...,
|
||||
}
|
||||
|
||||
def event_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"name": ...,
|
||||
"pool_id": ...,
|
||||
"os": ...,
|
||||
"state": ...,
|
||||
"managed": ...,
|
||||
}
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"pool_id": ...,
|
||||
"os": ...,
|
||||
"state": ...,
|
||||
"managed": ...,
|
||||
}
|
||||
|
||||
def populate_scaleset_summary(self) -> None:
|
||||
self.scaleset_summary = [
|
||||
ScalesetSummary(scaleset_id=x.scaleset_id, state=x.state)
|
||||
for x in Scaleset.search_by_pool(self.name)
|
||||
]
|
||||
|
||||
def populate_work_queue(self) -> None:
|
||||
self.work_queue = []
|
||||
|
||||
# Only populate the work queue summaries if the pool is initialized. We
|
||||
# can then be sure that the queue is available in the operations below.
|
||||
if self.state == PoolState.init:
|
||||
return
|
||||
|
||||
worksets = peek_queue(
|
||||
self.get_pool_queue(), account_id=get_fuzz_storage(), object_type=WorkSet
|
||||
)
|
||||
|
||||
for workset in worksets:
|
||||
work_units = [
|
||||
WorkUnitSummary(
|
||||
job_id=work_unit.job_id,
|
||||
task_id=work_unit.task_id,
|
||||
task_type=work_unit.task_type,
|
||||
)
|
||||
for work_unit in workset.work_units
|
||||
]
|
||||
self.work_queue.append(WorkSetSummary(work_units=work_units))
|
||||
|
||||
def get_pool_queue(self) -> str:
|
||||
return "pool-%s" % self.pool_id.hex
|
||||
|
||||
def init(self) -> None:
|
||||
create_queue(self.get_pool_queue(), account_id=get_fuzz_storage())
|
||||
self.state = PoolState.running
|
||||
self.save()
|
||||
|
||||
def schedule_workset(self, work_set: WorkSet) -> bool:
|
||||
# Don't schedule work for pools that can't and won't do work.
|
||||
if self.state in [PoolState.shutdown, PoolState.halt]:
|
||||
return False
|
||||
|
||||
return queue_object(
|
||||
self.get_pool_queue(), work_set, account_id=get_fuzz_storage()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, pool_id: UUID) -> Union[Error, "Pool"]:
|
||||
pools = cls.search(query={"pool_id": [pool_id]})
|
||||
if not pools:
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])
|
||||
|
||||
if len(pools) != 1:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
|
||||
)
|
||||
pool = pools[0]
|
||||
return pool
|
||||
|
||||
@classmethod
|
||||
def get_by_name(cls, name: PoolName) -> Union[Error, "Pool"]:
|
||||
pools = cls.search(query={"name": [name]})
|
||||
if not pools:
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])
|
||||
|
||||
if len(pools) != 1:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
|
||||
)
|
||||
pool = pools[0]
|
||||
return pool
|
||||
|
||||
@classmethod
|
||||
def search_states(cls, *, states: Optional[List[PoolState]] = None) -> List["Pool"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
""" shutdown allows nodes to finish current work then delete """
|
||||
scalesets = Scaleset.search_by_pool(self.name)
|
||||
nodes = Node.search(query={"pool_name": [self.name]})
|
||||
if not scalesets and not nodes:
|
||||
logging.info("pool stopped, deleting: %s", self.name)
|
||||
|
||||
self.state = PoolState.halt
|
||||
self.delete()
|
||||
return
|
||||
|
||||
for scaleset in scalesets:
|
||||
scaleset.state = ScalesetState.shutdown
|
||||
scaleset.save()
|
||||
|
||||
for node in nodes:
|
||||
node.state = NodeState.shutdown
|
||||
node.save()
|
||||
|
||||
self.save()
|
||||
|
||||
def halt(self) -> None:
|
||||
""" halt the pool immediately """
|
||||
scalesets = Scaleset.search_by_pool(self.name)
|
||||
nodes = Node.search(query={"pool_name": [self.name]})
|
||||
if not scalesets and not nodes:
|
||||
delete_queue(self.get_pool_queue(), account_id=get_fuzz_storage())
|
||||
logging.info("pool stopped, deleting: %s", self.name)
|
||||
self.state = PoolState.halt
|
||||
self.delete()
|
||||
return
|
||||
|
||||
for scaleset in scalesets:
|
||||
scaleset.state = ScalesetState.halt
|
||||
scaleset.save()
|
||||
|
||||
for node in nodes:
|
||||
logging.info(
|
||||
"deleting node from pool: %s (%s) - machine_id:%s",
|
||||
self.pool_id,
|
||||
self.name,
|
||||
node.machine_id,
|
||||
)
|
||||
node.delete()
|
||||
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("name", "pool_id")
|
||||
|
||||
|
||||
class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {"nodes": ...}
|
||||
|
||||
def event_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"pool_name": ...,
|
||||
"scaleset_id": ...,
|
||||
"state": ...,
|
||||
"os": ...,
|
||||
"size": ...,
|
||||
"error": ...,
|
||||
}
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"scaleset_id": ...,
|
||||
"os": ...,
|
||||
"vm_sku": ...,
|
||||
"size": ...,
|
||||
"spot_instances": ...,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
pool_name: PoolName,
|
||||
vm_sku: str,
|
||||
image: str,
|
||||
region: Region,
|
||||
size: int,
|
||||
spot_instances: bool,
|
||||
tags: Dict[str, str],
|
||||
client_id: Optional[UUID] = None,
|
||||
client_object_id: Optional[UUID] = None,
|
||||
) -> "Scaleset":
|
||||
return cls(
|
||||
pool_name=pool_name,
|
||||
vm_sku=vm_sku,
|
||||
image=image,
|
||||
region=region,
|
||||
size=size,
|
||||
spot_instances=spot_instances,
|
||||
auth=build_auth(),
|
||||
client_id=client_id,
|
||||
client_object_id=client_object_id,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def search_by_pool(cls, pool_name: PoolName) -> List["Scaleset"]:
|
||||
return cls.search(query={"pool_name": [pool_name]})
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, scaleset_id: UUID) -> Union[Error, "Scaleset"]:
|
||||
scalesets = cls.search(query={"scaleset_id": [scaleset_id]})
|
||||
if not scalesets:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["unable to find scaleset"]
|
||||
)
|
||||
|
||||
if len(scalesets) != 1:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["error identifying scaleset"]
|
||||
)
|
||||
scaleset = scalesets[0]
|
||||
return scaleset
|
||||
|
||||
@classmethod
|
||||
def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]:
|
||||
return cls.search(query={"client_object_id": [object_id]})
|
||||
|
||||
def init(self) -> None:
|
||||
logging.info("scaleset init: %s", self.scaleset_id)
|
||||
|
||||
# Handle the race condition between a pool being deleted and a
|
||||
# scaleset being added to the pool.
|
||||
pool = Pool.get_by_name(self.pool_name)
|
||||
if isinstance(pool, Error):
|
||||
self.error = pool
|
||||
self.state = ScalesetState.halt
|
||||
self.save()
|
||||
return
|
||||
|
||||
if pool.state == PoolState.init:
|
||||
logging.info(
|
||||
"scaleset waiting for pool: %s - %s", self.pool_name, self.scaleset_id
|
||||
)
|
||||
elif pool.state == PoolState.running:
|
||||
image_os = get_os(self.region, self.image)
|
||||
if isinstance(image_os, Error):
|
||||
self.error = image_os
|
||||
self.state = ScalesetState.creation_failed
|
||||
elif image_os != pool.os:
|
||||
self.error = Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)],
|
||||
)
|
||||
self.state = ScalesetState.creation_failed
|
||||
else:
|
||||
self.state = ScalesetState.setup
|
||||
else:
|
||||
self.state = ScalesetState.setup
|
||||
|
||||
self.save()
|
||||
|
||||
def setup(self) -> None:
|
||||
        # TODO: How do we pass in SSH configs for Windows? Previously,
        # this was done as part of the generated per-task setup script.
|
||||
logging.info("scaleset setup: %s", self.scaleset_id)
|
||||
|
||||
network = Network(self.region)
|
||||
network_id = network.get_id()
|
||||
if not network_id:
|
||||
logging.info("creating network: %s", self.region)
|
||||
result = network.create()
|
||||
if isinstance(result, Error):
|
||||
self.error = result
|
||||
self.state = ScalesetState.creation_failed
|
||||
self.save()
|
||||
return
|
||||
|
||||
if self.auth is None:
|
||||
self.error = Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"]
|
||||
)
|
||||
self.state = ScalesetState.creation_failed
|
||||
self.save()
|
||||
return
|
||||
|
||||
vmss = get_vmss(self.scaleset_id)
|
||||
if vmss is None:
|
||||
pool = Pool.get_by_name(self.pool_name)
|
||||
if isinstance(pool, Error):
|
||||
self.error = pool
|
||||
self.state = ScalesetState.halt
|
||||
self.save()
|
||||
return
|
||||
|
||||
logging.info("creating scaleset: %s", self.scaleset_id)
|
||||
extensions = fuzz_extensions(self.region, pool.os, self.pool_name)
|
||||
result = create_vmss(
|
||||
self.region,
|
||||
self.scaleset_id,
|
||||
self.vm_sku,
|
||||
self.size,
|
||||
self.image,
|
||||
network_id,
|
||||
self.spot_instances,
|
||||
extensions,
|
||||
self.auth.password,
|
||||
self.auth.public_key,
|
||||
self.tags,
|
||||
)
|
||||
if isinstance(result, Error):
|
||||
self.error = result
|
||||
logging.error(
|
||||
"stopping task because of failed vmss: %s %s",
|
||||
self.scaleset_id,
|
||||
result,
|
||||
)
|
||||
self.state = ScalesetState.creation_failed
|
||||
else:
|
||||
logging.info("creating scaleset: %s", self.scaleset_id)
|
||||
elif vmss.provisioning_state == "Creating":
|
||||
logging.info("Waiting on scaleset creation: %s", self.scaleset_id)
|
||||
else:
|
||||
logging.info("scaleset running: %s", self.scaleset_id)
|
||||
self.state = ScalesetState.running
|
||||
self.client_object_id = vmss.identity.principal_id
|
||||
self.save()
|
||||
|
||||
# result = 'did I modify the scaleset in azure'
|
||||
def cleanup_nodes(self) -> bool:
|
||||
if self.state == ScalesetState.halt:
|
||||
self.halt()
|
||||
return True
|
||||
|
||||
nodes = Node.search_states(
|
||||
scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
|
||||
)
|
||||
if not nodes:
|
||||
logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
|
||||
return False
|
||||
|
||||
to_delete = []
|
||||
to_reimage = []
|
||||
|
||||
for node in nodes:
|
||||
# delete nodes that are not waiting on the scaleset GC
|
||||
if not node.scaleset_node_exists():
|
||||
node.delete()
|
||||
elif node.state in [NodeState.shutdown, NodeState.halt]:
|
||||
to_delete.append(node)
|
||||
else:
|
||||
to_reimage.append(node)
|
||||
|
||||
# Perform operations until they fail due to scaleset getting locked
|
||||
try:
|
||||
if to_delete:
|
||||
self.delete_nodes(to_delete)
|
||||
for node in to_delete:
|
||||
node.state = NodeState.halt
|
||||
node.save()
|
||||
|
||||
if to_reimage:
|
||||
self.reimage_nodes(to_reimage)
|
||||
except UnableToUpdate:
|
||||
logging.info("scaleset update already in progress: %s", self.scaleset_id)
|
||||
return True
|
||||
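    # NOTE (illustrative sketch, an assumption about the caller rather than part
    # of the original source): the boolean above ("did I modify the scaleset in
    # Azure") presumably lets a periodic worker skip further scaleset updates
    # while an Azure operation is in flight, e.g.:
    #
    #   if scaleset.cleanup_nodes():
    #       return  # scaleset was touched; retry config updates on the next pass
    #   scaleset.update_configs()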
|
||||
def resize(self) -> None:
|
||||
logging.info(
|
||||
"scaleset resize: %s - current: %s new: %s",
|
||||
self.scaleset_id,
|
||||
self.size,
|
||||
self.new_size,
|
||||
)
|
||||
|
||||
# no work needed to resize
|
||||
if self.new_size is None:
|
||||
self.state = ScalesetState.running
|
||||
self.save()
|
||||
return
|
||||
|
||||
# just in case, always ensure size is within max capacity
|
||||
self.new_size = min(self.new_size, self.max_size())
|
||||
|
||||
# Treat Azure knowledge of the size of the scaleset as "ground truth"
|
||||
size = get_vmss_size(self.scaleset_id)
|
||||
if size is None:
|
||||
logging.info("scaleset is unavailable. Re-queuing")
|
||||
self.save()
|
||||
return
|
||||
|
||||
if size == self.new_size:
|
||||
# NOTE: this is the only place we reset to the 'running' state.
|
||||
# This ensures that our idea of scaleset size agrees with Azure
|
||||
node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
|
||||
if node_count == self.size:
|
||||
logging.info("resize finished: %s", self.scaleset_id)
|
||||
self.new_size = None
|
||||
self.state = ScalesetState.running
|
||||
else:
|
||||
logging.info(
|
||||
"resize is finished, waiting for nodes to check in: "
|
||||
"%s (%d of %d nodes checked in)",
|
||||
self.scaleset_id,
|
||||
node_count,
|
||||
self.size,
|
||||
)
|
||||
# When adding capacity, call the resize API directly
|
||||
elif self.new_size > self.size:
|
||||
try:
|
||||
resize_vmss(self.scaleset_id, self.new_size)
|
||||
except UnableToUpdate:
|
||||
logging.info("scaleset is mid-operation already")
|
||||
        # Shut down any nodes without work. Otherwise, rely on Scaleset.reimage_nodes
|
||||
# to pick up that the scaleset is too big upon task completion
|
||||
else:
|
||||
nodes = Node.search_states(
|
||||
scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free]
|
||||
)
|
||||
for node in nodes:
|
||||
if size > self.new_size:
|
||||
node.state = NodeState.halt
|
||||
node.save()
|
||||
size -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
self.save()
|
||||
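    # NOTE (illustrative worked example, not part of the original source): if
    # Azure reports size == 10 and new_size == 7, the shrink branch above halts
    # up to three nodes that are currently init/free; busy nodes are left alone,
    # and the remaining excess is reclaimed as tasks finish and nodes are
    # reimaged or deleted.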
|
||||
def delete_nodes(self, nodes: List[Node]) -> None:
|
||||
if not nodes:
|
||||
logging.debug("no nodes to delete")
|
||||
return
|
||||
|
||||
if self.state == ScalesetState.halt:
|
||||
logging.debug("scaleset delete will delete node: %s", self.scaleset_id)
|
||||
return
|
||||
|
||||
machine_ids = [x.machine_id for x in nodes]
|
||||
|
||||
logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
|
||||
delete_vmss_nodes(self.scaleset_id, machine_ids)
|
||||
self.size -= len(machine_ids)
|
||||
self.save()
|
||||
|
||||
def reimage_nodes(self, nodes: List[Node]) -> None:
|
||||
from .tasks.main import Task
|
||||
|
||||
if not nodes:
|
||||
logging.debug("no nodes to reimage")
|
||||
return
|
||||
|
||||
for node in nodes:
|
||||
for entry in NodeTasks.get_by_machine_id(node.machine_id):
|
||||
task = Task.get_by_task_id(entry.task_id)
|
||||
if isinstance(task, Task):
|
||||
if task.state in [TaskState.stopping, TaskState.stopped]:
|
||||
continue
|
||||
|
||||
task.error = Error(
|
||||
code=ErrorCode.TASK_FAILED,
|
||||
errors=["node reimaged during task execution"],
|
||||
)
|
||||
task.state = TaskState.stopping
|
||||
task.save()
|
||||
entry.delete()
|
||||
|
||||
if self.state == ScalesetState.shutdown:
|
||||
self.delete_nodes(nodes)
|
||||
return
|
||||
|
||||
if self.state == ScalesetState.halt:
|
||||
logging.debug("scaleset delete will delete node: %s", self.scaleset_id)
|
||||
return
|
||||
|
||||
machine_ids = [x.machine_id for x in nodes]
|
||||
|
||||
result = reimage_vmss_nodes(self.scaleset_id, machine_ids)
|
||||
if isinstance(result, Error):
|
||||
raise Exception(
|
||||
"unable to reimage nodes: %s:%s - %s"
|
||||
% (self.scaleset_id, machine_ids, result)
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
logging.info("scaleset shutdown: %s", self.scaleset_id)
|
||||
size = get_vmss_size(self.scaleset_id)
|
||||
if size is None or size == 0:
|
||||
self.state = ScalesetState.halt
|
||||
self.halt()
|
||||
return
|
||||
self.save()
|
||||
|
||||
def halt(self) -> None:
|
||||
for node in Node.search_states(scaleset_id=self.scaleset_id):
|
||||
logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
|
||||
node.delete()
|
||||
|
||||
vmss = get_vmss(self.scaleset_id)
|
||||
if vmss is None:
|
||||
logging.info("scaleset deleted: %s", self.scaleset_id)
|
||||
self.state = ScalesetState.halt
|
||||
self.delete()
|
||||
else:
|
||||
logging.info("scaleset deleting: %s", self.scaleset_id)
|
||||
delete_vmss(self.scaleset_id)
|
||||
self.save()
|
||||
|
||||
def max_size(self) -> int:
|
||||
# https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
|
||||
# virtual-machine-scale-sets-placement-groups#checklist-for-using-large-scale-sets
|
||||
if self.image.startswith("/"):
|
||||
return 600
|
||||
else:
|
||||
return 1000
|
||||
|
||||
@classmethod
|
||||
def search_states(
|
||||
cls, *, states: Optional[List[ScalesetState]] = None
|
||||
) -> List["Scaleset"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
def update_nodes(self) -> None:
|
||||
        # Be at least in the 'setup' state before checking for the list of VMs
        if self.state == ScalesetState.init:
|
||||
return
|
||||
|
||||
nodes = Node.search_states(scaleset_id=self.scaleset_id)
|
||||
azure_nodes = list_instance_ids(self.scaleset_id)
|
||||
|
||||
self.nodes = []
|
||||
|
||||
for (machine_id, instance_id) in azure_nodes.items():
|
||||
node_state: Optional[ScalesetNodeState] = None
|
||||
for node in nodes:
|
||||
if node.machine_id == machine_id:
|
||||
node_state = ScalesetNodeState(
|
||||
machine_id=machine_id,
|
||||
instance_id=instance_id,
|
||||
state=node.state,
|
||||
)
|
||||
break
|
||||
if not node_state:
|
||||
node_state = ScalesetNodeState(
|
||||
machine_id=machine_id,
|
||||
instance_id=instance_id,
|
||||
)
|
||||
self.nodes.append(node_state)
|
||||
|
||||
def update_configs(self) -> None:
|
||||
if self.state != ScalesetState.running:
|
||||
logging.debug(
|
||||
"scaleset not running, not updating configs: %s", self.scaleset_id
|
||||
)
|
||||
return
|
||||
|
||||
pool = Pool.get_by_name(self.pool_name)
|
||||
if isinstance(pool, Error):
|
||||
self.error = pool
|
||||
return self.halt()
|
||||
|
||||
logging.debug("updating scaleset configs: %s", self.scaleset_id)
|
||||
extensions = fuzz_extensions(self.region, pool.os, self.pool_name)
|
||||
try:
|
||||
update_extensions(self.scaleset_id, extensions)
|
||||
except UnableToUpdate:
|
||||
logging.debug(
|
||||
"unable to update configs, update already in progress: %s",
|
||||
self.scaleset_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("pool_name", "scaleset_id")
|
236
src/api-service/__app__/onefuzzlib/proxy.py
Normal file
236
src/api-service/__app__/onefuzzlib/proxy.py
Normal file
@ -0,0 +1,236 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from azure.mgmt.compute.models import VirtualMachine
|
||||
from onefuzztypes.enums import VmState
|
||||
from onefuzztypes.models import (
|
||||
Authentication,
|
||||
Error,
|
||||
Forward,
|
||||
ProxyConfig,
|
||||
ProxyHeartbeat,
|
||||
)
|
||||
from onefuzztypes.primitives import Region
|
||||
from pydantic import Field
|
||||
|
||||
from .__version__ import __version__
|
||||
from .azure.auth import build_auth
|
||||
from .azure.containers import get_file_sas_url, save_blob
|
||||
from .azure.ip import get_public_ip
|
||||
from .azure.queue import get_queue_sas
|
||||
from .azure.vm import VM
|
||||
from .extension import proxy_manager_extensions
|
||||
from .orm import HOURS, MappingIntStrAny, ORMMixin, QueryFilter
|
||||
from .proxy_forward import ProxyForward
|
||||
|
||||
PROXY_SKU = "Standard_B2s"
|
||||
PROXY_IMAGE = "Canonical:UbuntuServer:18.04-LTS:latest"
|
||||
|
||||
|
||||
# This isn't intended to ever be shared to the client, hence not being in
|
||||
# onefuzztypes
|
||||
class Proxy(ORMMixin):
|
||||
region: Region
|
||||
state: VmState = Field(default=VmState.init)
|
||||
auth: Authentication = Field(default_factory=build_auth)
|
||||
ip: Optional[str]
|
||||
error: Optional[str]
|
||||
version: str = Field(default=__version__)
|
||||
heartbeat: Optional[ProxyHeartbeat]
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("region", None)
|
||||
|
||||
def event_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"region": ...,
|
||||
"state": ...,
|
||||
"ip": ...,
|
||||
"error": ...,
|
||||
}
|
||||
|
||||
def get_vm(self) -> VM:
|
||||
vm = VM(
|
||||
name="proxy-%s" % self.region,
|
||||
region=self.region,
|
||||
sku=PROXY_SKU,
|
||||
image=PROXY_IMAGE,
|
||||
auth=self.auth,
|
||||
)
|
||||
return vm
|
||||
|
||||
    def init(self) -> None:
        vm = self.get_vm()
        vm_data = vm.get()
        if vm_data:
            if vm_data.provisioning_state == "Failed":
                self.set_failed(vm_data)
            else:
                self.save_proxy_config()
                self.state = VmState.extensions_launch
        else:
            result = vm.create()
            if isinstance(result, Error):
                self.error = repr(result)
                self.state = VmState.stopping
        self.save()
|
||||
|
||||
def set_failed(self, vm_data: VirtualMachine) -> None:
|
||||
logging.error("vm failed to provision: %s", vm_data.name)
|
||||
for status in vm_data.instance_view.statuses:
|
||||
if status.level.name.lower() == "error":
|
||||
logging.error(
|
||||
"vm status: %s %s %s %s",
|
||||
vm_data.name,
|
||||
status.code,
|
||||
status.display_status,
|
||||
status.message,
|
||||
)
|
||||
self.state = VmState.vm_allocation_failed
|
||||
|
||||
def extensions_launch(self) -> None:
|
||||
vm = self.get_vm()
|
||||
vm_data = vm.get()
|
||||
if not vm_data:
|
||||
logging.error("Azure VM does not exist: %s", vm.name)
|
||||
self.state = VmState.stopping
|
||||
self.save()
|
||||
return
|
||||
|
||||
if vm_data.provisioning_state == "Failed":
|
||||
self.set_failed(vm_data)
|
||||
self.save()
|
||||
return
|
||||
|
||||
ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
|
||||
if ip is None:
|
||||
self.save()
|
||||
return
|
||||
self.ip = ip
|
||||
|
||||
extensions = proxy_manager_extensions(self.region)
|
||||
result = vm.add_extensions(extensions)
|
||||
if isinstance(result, Error):
|
||||
logging.error("vm extension failed: %s", repr(result))
|
||||
self.error = repr(result)
|
||||
self.state = VmState.stopping
|
||||
elif result:
|
||||
self.state = VmState.running
|
||||
|
||||
self.save()
|
||||
|
||||
def stopping(self) -> None:
|
||||
vm = self.get_vm()
|
||||
if not vm.is_deleted():
|
||||
logging.info("stopping proxy: %s", self.region)
|
||||
vm.delete()
|
||||
self.save()
|
||||
else:
|
||||
self.stopped()
|
||||
|
||||
def stopped(self) -> None:
|
||||
logging.info("removing proxy: %s", self.region)
|
||||
self.delete()
|
||||
|
||||
def is_used(self) -> bool:
|
||||
if len(self.get_forwards()) == 0:
|
||||
logging.info("proxy has no forwards: %s", self.region)
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
        # Unfortunately, timezone-aware and naive datetimes cannot be compared
        # directly, so both forms are needed here to avoid exceptions
|
||||
ten_minutes_ago_no_tz = datetime.datetime.utcnow() - datetime.timedelta(
|
||||
minutes=10
|
||||
)
|
||||
ten_minutes_ago = ten_minutes_ago_no_tz.astimezone(datetime.timezone.utc)
|
||||
if (
|
||||
self.heartbeat is not None
|
||||
and self.heartbeat.timestamp < ten_minutes_ago_no_tz
|
||||
):
|
||||
logging.error(
|
||||
"proxy last heartbeat is more than an 10 minutes old: %s", self.region
|
||||
)
|
||||
return False
|
||||
|
||||
elif not self.heartbeat and self.Timestamp and self.Timestamp < ten_minutes_ago:
|
||||
logging.error(
|
||||
"proxy has no heartbeat in the last 10 minutes: %s", self.region
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
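    # NOTE: minimal illustrative sketch (not part of the original module) of the
    # timezone issue handled in is_alive: naive and aware datetimes cannot be
    # compared directly, so one side has to be normalized first.
    @staticmethod
    def _tz_compare_sketch() -> None:
        naive = datetime.datetime.utcnow()
        aware = datetime.datetime.now(tz=datetime.timezone.utc)
        try:
            _ = naive < aware
        except TypeError:
            # "can't compare offset-naive and offset-aware datetimes"
            pass
        # normalizing the naive value (as is_alive does via astimezone) allows the compare
        _ = naive.astimezone(datetime.timezone.utc) < aware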
|
||||
def get_forwards(self) -> List[Forward]:
|
||||
forwards: List[Forward] = []
|
||||
for entry in ProxyForward.search_forward(region=self.region):
|
||||
if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
|
||||
entry.delete()
|
||||
else:
|
||||
forwards.append(
|
||||
Forward(
|
||||
src_port=entry.port,
|
||||
dst_ip=entry.dst_ip,
|
||||
dst_port=entry.dst_port,
|
||||
)
|
||||
)
|
||||
return forwards
|
||||
|
||||
def save_proxy_config(self) -> None:
|
||||
forwards = self.get_forwards()
|
||||
proxy_config = ProxyConfig(
|
||||
url=get_file_sas_url(
|
||||
"proxy-configs",
|
||||
"%s/config.json" % self.region,
|
||||
account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
|
||||
read=True,
|
||||
),
|
||||
notification=get_queue_sas(
|
||||
"proxy",
|
||||
account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
|
||||
add=True,
|
||||
),
|
||||
forwards=forwards,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
save_blob(
|
||||
"proxy-configs",
|
||||
"%s/config.json" % self.region,
|
||||
proxy_config.json(),
|
||||
account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
|
||||
)
|
||||
|
||||
def queue_stop(self, count: int) -> None:
|
||||
self.queue(method=self.stopping, visibility_timeout=count * HOURS)
|
||||
|
||||
@classmethod
|
||||
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
@classmethod
|
||||
def get_or_create(cls, region: Region) -> Optional["Proxy"]:
|
||||
proxy = Proxy.get(region)
|
||||
if proxy is not None:
|
||||
if proxy.version != __version__:
|
||||
# If the proxy is out-of-date, delete and re-create it
|
||||
proxy.state = VmState.stopping
|
||||
proxy.save()
|
||||
return None
|
||||
return proxy
|
||||
|
||||
proxy = Proxy(region=region)
|
||||
proxy.save()
|
||||
return proxy
|
134
src/api-service/__app__/onefuzzlib/proxy_forward.py
Normal file
134
src/api-service/__app__/onefuzzlib/proxy_forward.py
Normal file
@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, Forward
|
||||
from onefuzztypes.primitives import Region
|
||||
from pydantic import Field
|
||||
|
||||
from .azure.ip import get_scaleset_instance_ip
|
||||
from .orm import ORMMixin, QueryFilter
|
||||
|
||||
PORT_RANGES = range(6000, 7000)
|
||||
|
||||
|
||||
# This isn't intended to ever be shared to the client, hence not being in
|
||||
# onefuzztypes
|
||||
class ProxyForward(ORMMixin):
|
||||
region: Region
|
||||
port: int
|
||||
scaleset_id: UUID
|
||||
machine_id: UUID
|
||||
dst_ip: str
|
||||
dst_port: int
|
||||
endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("region", "port")
|
||||
|
||||
@classmethod
|
||||
def update_or_create(
|
||||
cls,
|
||||
region: Region,
|
||||
scaleset_id: UUID,
|
||||
machine_id: UUID,
|
||||
dst_port: int,
|
||||
duration: int,
|
||||
) -> Union["ProxyForward", Error]:
|
||||
private_ip = get_scaleset_instance_ip(scaleset_id, machine_id)
|
||||
if not private_ip:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["no private ip for node"]
|
||||
)
|
||||
|
||||
entries = cls.search_forward(
|
||||
scaleset_id=scaleset_id,
|
||||
machine_id=machine_id,
|
||||
dst_port=dst_port,
|
||||
region=region,
|
||||
)
|
||||
if entries:
|
||||
entry = entries[0]
|
||||
entry.endtime = datetime.datetime.utcnow() + datetime.timedelta(
|
||||
hours=duration
|
||||
)
|
||||
entry.save()
|
||||
return entry
|
||||
|
||||
existing = [int(x.port) for x in entries]
|
||||
for port in PORT_RANGES:
|
||||
if port in existing:
|
||||
continue
|
||||
|
||||
entry = cls(
|
||||
region=region,
|
||||
port=port,
|
||||
scaleset_id=scaleset_id,
|
||||
machine_id=machine_id,
|
||||
dst_ip=private_ip,
|
||||
dst_port=dst_port,
|
||||
endtime=datetime.datetime.utcnow() + datetime.timedelta(hours=duration),
|
||||
)
|
||||
result = entry.save(new=True)
|
||||
if isinstance(result, Error):
|
||||
logging.info("port is already used: %s", entry)
|
||||
continue
|
||||
|
||||
return entry
|
||||
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["all forward ports used"]
|
||||
)
|
||||
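    # NOTE (illustrative, not part of the original source): repeated calls to
    # update_or_create for the same (scaleset, machine, dst_port) extend the
    # endtime of the existing forward instead of allocating a new port; only a
    # new destination consumes the next free port in PORT_RANGES.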
|
||||
@classmethod
|
||||
def remove_forward(
|
||||
cls,
|
||||
scaleset_id: UUID,
|
||||
*,
|
||||
machine_id: Optional[UUID] = None,
|
||||
dst_port: Optional[int] = None,
|
||||
) -> List[Region]:
|
||||
entries = cls.search_forward(
|
||||
scaleset_id=scaleset_id, machine_id=machine_id, dst_port=dst_port
|
||||
)
|
||||
regions = set()
|
||||
for entry in entries:
|
||||
regions.add(entry.region)
|
||||
entry.delete()
|
||||
return list(regions)
|
||||
|
||||
@classmethod
|
||||
def search_forward(
|
||||
cls,
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
region: Optional[Region] = None,
|
||||
machine_id: Optional[UUID] = None,
|
||||
dst_port: Optional[int] = None,
|
||||
) -> List["ProxyForward"]:
|
||||
|
||||
query: QueryFilter = {}
|
||||
if region is not None:
|
||||
query["region"] = [region]
|
||||
|
||||
if scaleset_id is not None:
|
||||
query["scaleset_id"] = [scaleset_id]
|
||||
|
||||
if machine_id is not None:
|
||||
query["machine_id"] = [machine_id]
|
||||
|
||||
if dst_port is not None:
|
||||
query["dst_port"] = [dst_port]
|
||||
|
||||
return cls.search(query=query)
|
||||
|
||||
def to_forward(self) -> Forward:
|
||||
return Forward(src_port=self.port, dst_ip=self.dst_ip, dst_port=self.dst_port)
|
56
src/api-service/__app__/onefuzzlib/reports.py
Normal file
56
src/api-service/__app__/onefuzzlib/reports.py
Normal file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from onefuzztypes.models import Report
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .azure.containers import get_blob
|
||||
|
||||
|
||||
def parse_report(content: Union[bytes, str], metadata: Optional[str] = None) -> Optional[Report]:
|
||||
if isinstance(content, bytes):
|
||||
try:
|
||||
content = content.decode()
|
||||
except UnicodeDecodeError as err:
|
||||
logging.error(
|
||||
"unable to parse report (%s): unicode decode of report failed - %s",
|
||||
metadata,
|
||||
err,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.decoder.JSONDecodeError as err:
|
||||
logging.error(
|
||||
"unable to parse report (%s): json decoding failed - %s", metadata, err
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
entry = Report.parse_obj(data)
|
||||
except ValidationError as err:
|
||||
logging.error("unable to parse report (%s): %s", metadata, err)
|
||||
return None
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def get_report(container: str, filename: str) -> Optional[Report]:
|
||||
metadata = "/".join([container, filename])
|
||||
if not filename.endswith(".json"):
|
||||
logging.error("get_report invalid extension: %s", metadata)
|
||||
return None
|
||||
|
||||
blob = get_blob(container, filename)
|
||||
if blob is None:
|
||||
logging.error("get_report invalid blob: %s", metadata)
|
||||
return None
|
||||
|
||||
return parse_report(blob, metadata=metadata)
|
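# NOTE: minimal illustrative sketch (not part of the original module): both
# helpers return None for anything that is not a valid crash report, so callers
# only need a single None check.
def _parse_report_sketch() -> None:
    assert parse_report("not json", metadata="example/invalid.json") is None
    assert parse_report('{"unexpected": "shape"}', metadata="example/bad.json") is None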
237
src/api-service/__app__/onefuzzlib/repro.py
Normal file
237
src/api-service/__app__/onefuzzlib/repro.py
Normal file
@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from azure.mgmt.compute.models import VirtualMachine
|
||||
from onefuzztypes.enums import OS, ContainerType, ErrorCode, VmState
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Repro as BASE_REPRO
|
||||
from onefuzztypes.models import ReproConfig, TaskVm
|
||||
|
||||
from .azure.auth import build_auth
|
||||
from .azure.containers import save_blob
|
||||
from .azure.creds import get_base_region, get_func_storage
|
||||
from .azure.ip import get_public_ip
|
||||
from .azure.vm import VM
|
||||
from .extension import repro_extensions
|
||||
from .orm import HOURS, ORMMixin, QueryFilter
|
||||
from .reports import get_report
|
||||
from .tasks.main import Task
|
||||
|
||||
DEFAULT_OS = {
|
||||
OS.linux: "Canonical:UbuntuServer:18.04-LTS:latest",
|
||||
OS.windows: "MicrosoftWindowsDesktop:Windows-10:rs5-pro:latest",
|
||||
}
|
||||
|
||||
DEFAULT_SKU = "Standard_DS1_v2"
|
||||
|
||||
|
||||
class Repro(BASE_REPRO, ORMMixin):
|
||||
def set_error(self, error: Error) -> None:
|
||||
logging.error(
|
||||
"repro failed: vm_id: %s task_id: %s: error: %s",
|
||||
self.vm_id,
|
||||
self.task_id,
|
||||
error,
|
||||
)
|
||||
self.error = error
|
||||
self.state = VmState.stopping
|
||||
self.save()
|
||||
|
||||
def get_vm(self) -> VM:
|
||||
task = Task.get_by_task_id(self.task_id)
|
||||
if isinstance(task, Error):
|
||||
raise Exception("previously existing task missing: %s", self.task_id)
|
||||
|
||||
vm_config = task.get_repro_vm_config()
|
||||
if vm_config is None:
|
||||
# if using a pool without any scalesets defined yet, use reasonable defaults
|
||||
if task.os not in DEFAULT_OS:
|
||||
raise NotImplementedError("unsupported OS for repro %s" % task.os)
|
||||
|
||||
vm_config = TaskVm(
|
||||
region=get_base_region(), sku=DEFAULT_SKU, image=DEFAULT_OS[task.os]
|
||||
)
|
||||
|
||||
if self.auth is None:
|
||||
raise Exception("missing auth")
|
||||
|
||||
return VM(
|
||||
name=self.vm_id,
|
||||
region=vm_config.region,
|
||||
sku=vm_config.sku,
|
||||
image=vm_config.image,
|
||||
auth=self.auth,
|
||||
)
|
||||
|
||||
    def init(self) -> None:
        vm = self.get_vm()
        vm_data = vm.get()
        if vm_data:
            if vm_data.provisioning_state == "Failed":
                self.set_failed(vm_data)
            else:
                script_result = self.build_repro_script()
                if isinstance(script_result, Error):
                    return self.set_error(script_result)

                self.state = VmState.extensions_launch
        else:
            result = vm.create()
            if isinstance(result, Error):
                return self.set_error(result)
        self.save()
|
||||
|
||||
def set_failed(self, vm_data: VirtualMachine) -> None:
|
||||
errors = []
|
||||
for status in vm_data.instance_view.statuses:
|
||||
if status.level.name.lower() == "error":
|
||||
errors.append(
|
||||
"%s %s %s" % (status.code, status.display_status, status.message)
|
||||
)
|
||||
return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors))
|
||||
|
||||
def get_setup_container(self) -> Optional[str]:
|
||||
task = Task.get_by_task_id(self.task_id)
|
||||
if isinstance(task, Task):
|
||||
for container in task.config.containers:
|
||||
if container.type == ContainerType.setup:
|
||||
return container.name
|
||||
return None
|
||||
|
||||
def extensions_launch(self) -> None:
|
||||
vm = self.get_vm()
|
||||
vm_data = vm.get()
|
||||
if not vm_data:
|
||||
return self.set_error(
|
||||
Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["failed before launching extensions"],
|
||||
)
|
||||
)
|
||||
|
||||
if vm_data.provisioning_state == "Failed":
|
||||
return self.set_failed(vm_data)
|
||||
|
||||
if not self.ip:
|
||||
self.ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
|
||||
|
||||
extensions = repro_extensions(
|
||||
vm.region, self.os, self.vm_id, self.config, self.get_setup_container()
|
||||
)
|
||||
result = vm.add_extensions(extensions)
|
||||
if isinstance(result, Error):
|
||||
return self.set_error(result)
|
||||
elif result:
|
||||
self.state = VmState.running
|
||||
|
||||
self.save()
|
||||
|
||||
def stopping(self) -> None:
|
||||
vm = self.get_vm()
|
||||
if not vm.is_deleted():
|
||||
logging.info("vm stopping: %s", self.vm_id)
|
||||
vm.delete()
|
||||
self.save()
|
||||
else:
|
||||
self.stopped()
|
||||
|
||||
def stopped(self) -> None:
|
||||
logging.info("vm stopped: %s", self.vm_id)
|
||||
self.delete()
|
||||
|
||||
def build_repro_script(self) -> Optional[Error]:
|
||||
if self.auth is None:
|
||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing auth"])
|
||||
|
||||
task = Task.get_by_task_id(self.task_id)
|
||||
if isinstance(task, Error):
|
||||
return task
|
||||
|
||||
report = get_report(self.config.container, self.config.path)
|
||||
if report is None:
|
||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing report"])
|
||||
|
||||
files = {}
|
||||
|
||||
if task.os == OS.windows:
|
||||
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
|
||||
cmds = [
|
||||
'Set-Content -Path %s -Value "%s"' % (ssh_path, self.auth.public_key),
|
||||
". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
|
||||
"Set-SetSSHACL",
|
||||
'while (1) { cdb -server tcp:port=1337 -c "g" setup\\%s %s }'
|
||||
% (
|
||||
task.config.task.target_exe,
|
||||
report.input_blob.name,
|
||||
),
|
||||
]
|
||||
cmd = "\r\n".join(cmds)
|
||||
files["repro.ps1"] = cmd
|
||||
elif task.os == OS.linux:
|
||||
gdb_fmt = (
|
||||
"ASAN_OPTIONS='abort_on_error=1' gdbserver "
|
||||
"%s /onefuzz/setup/%s /onefuzz/downloaded/%s"
|
||||
)
|
||||
cmd = "while :; do %s; done" % (
|
||||
gdb_fmt
|
||||
% (
|
||||
"localhost:1337",
|
||||
task.config.task.target_exe,
|
||||
report.input_blob.name,
|
||||
)
|
||||
)
|
||||
files["repro.sh"] = cmd
|
||||
|
||||
cmd = "#!/bin/bash\n%s" % (
|
||||
gdb_fmt % ("-", task.config.task.target_exe, report.input_blob.name)
|
||||
)
|
||||
files["repro-stdout.sh"] = cmd
|
||||
else:
|
||||
raise NotImplementedError("invalid task os: %s" % task.os)
|
||||
|
||||
for filename in files:
|
||||
save_blob(
|
||||
"repro-scripts",
|
||||
"%s/%s" % (self.vm_id, filename),
|
||||
files[filename],
|
||||
account_id=get_func_storage(),
|
||||
)
|
||||
|
||||
logging.info("saved repro script")
|
||||
return None
|
||||
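    # NOTE (illustrative, assumed target/input names): for a Linux task with
    # target_exe "fuzz.exe" and crash input "crash-abc", the generated repro.sh
    # would look roughly like:
    #
    #   while :; do ASAN_OPTIONS='abort_on_error=1' gdbserver localhost:1337 \
    #       /onefuzz/setup/fuzz.exe /onefuzz/downloaded/crash-abc; done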
|
||||
def queue_stop(self, count: int) -> None:
|
||||
self.queue(method=self.stopping, visibility_timeout=count * HOURS)
|
||||
|
||||
@classmethod
|
||||
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: ReproConfig) -> Union[Error, "Repro"]:
|
||||
report = get_report(config.container, config.path)
|
||||
if not report:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find report"]
|
||||
)
|
||||
|
||||
task = Task.get_by_task_id(report.task_id)
|
||||
if isinstance(task, Error):
|
||||
return task
|
||||
|
||||
vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
|
||||
vm.save()
|
||||
vm.queue_stop(config.duration)
|
||||
return vm
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("vm_id", None)
|
151
src/api-service/__app__/onefuzzlib/request.py
Normal file
151
src/api-service/__app__/onefuzzlib/request.py
Normal file
@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.functions import HttpRequest, HttpResponse
|
||||
from azure.graphrbac.models import GraphErrorException
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.responses import BaseResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .azure.creds import is_member_of
|
||||
from .orm import ModelMixin
|
||||
|
||||
# We don't actually use these types at runtime at this time. Rather,
|
||||
# these are used in a bound TypeVar. MyPy suggests to only import these
|
||||
# types during type checking.
|
||||
if TYPE_CHECKING:
|
||||
from onefuzztypes.requests import BaseRequest # noqa: F401
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
|
||||
def check_access(req: HttpRequest) -> Optional[Error]:
|
||||
if "ONEFUZZ_AAD_GROUP_ID" not in os.environ:
|
||||
return None
|
||||
|
||||
group_id = os.environ["ONEFUZZ_AAD_GROUP_ID"]
|
||||
member_id = req.headers["x-ms-client-principal-id"]
|
||||
try:
|
||||
result = is_member_of(group_id, member_id)
|
||||
except GraphErrorException:
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED, errors=["unable to interact with graph"]
|
||||
)
|
||||
if not result:
|
||||
logging.error("unauthorized access: %s is not in %s", member_id, group_id)
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["not approved to use this instance of onefuzz"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def ok(
|
||||
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
|
||||
) -> HttpResponse:
|
||||
if isinstance(data, BaseResponse):
|
||||
return HttpResponse(data.json(exclude_none=True), mimetype="application/json")
|
||||
|
||||
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], BaseResponse):
|
||||
decoded = [json.loads(x.json(exclude_none=True)) for x in data]
|
||||
return HttpResponse(json.dumps(decoded), mimetype="application/json")
|
||||
|
||||
if isinstance(data, ModelMixin):
|
||||
return HttpResponse(
|
||||
data.json(exclude_none=True, exclude=data.export_exclude()),
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
decoded = [
|
||||
x.raw(exclude_none=True, exclude=x.export_exclude())
|
||||
if isinstance(x, ModelMixin)
|
||||
else x
|
||||
for x in data
|
||||
]
|
||||
return HttpResponse(
|
||||
json.dumps(decoded),
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
|
||||
def not_ok(
|
||||
error: Error, *, status_code: int = 400, context: Union[str, UUID]
|
||||
) -> HttpResponse:
|
||||
    if 400 <= status_code <= 599:
|
||||
logging.error("request error - %s: %s" % (str(context), error.json()))
|
||||
|
||||
return HttpResponse(
|
||||
error.json(), status_code=status_code, mimetype="application/json"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"status code %s is not int the expected range [400; 599]" % status_code
|
||||
)
|
||||
|
||||
|
||||
def redirect(location: str) -> HttpResponse:
|
||||
return HttpResponse(status_code=302, headers={"Location": location})
|
||||
|
||||
|
||||
def convert_error(err: ValidationError) -> Error:
|
||||
errors = []
|
||||
for error in err.errors():
|
||||
if isinstance(error["loc"], tuple):
|
||||
name = ".".join([str(x) for x in error["loc"]])
|
||||
else:
|
||||
name = str(error["loc"])
|
||||
errors.append("%s: %s" % (name, error["msg"]))
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=errors)
|
||||
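# NOTE: minimal illustrative sketch (not part of the original module) of how
# convert_error flattens pydantic error locations; the model below is
# hypothetical and only used for illustration.
def _convert_error_sketch() -> Error:
    from pydantic import BaseModel

    class ExampleRequest(BaseModel):
        pool_name: str
        count: int

    try:
        ExampleRequest.parse_obj({"count": "not-a-number"})
    except ValidationError as err:
        # yields errors like ["pool_name: field required", "count: value is not a valid integer"]
        return convert_error(err)
    raise AssertionError("expected a validation error")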
|
||||
|
||||
# TODO: loosen restrictions here during dev. We should be specific
|
||||
# about only parsing things that are of a "Request" type, but there are
|
||||
# a handful of types that need work in order to enforce that.
|
||||
#
|
||||
# These can be easily found by swapping the following comment and running
|
||||
# mypy.
|
||||
#
|
||||
# A = TypeVar("A", bound="BaseRequest")
|
||||
A = TypeVar("A", bound="BaseModel")
|
||||
|
||||
|
||||
def parse_request(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
|
||||
access = check_access(req)
|
||||
if isinstance(access, Error):
|
||||
return access
|
||||
|
||||
try:
|
||||
return cls.parse_obj(req.get_json())
|
||||
except ValidationError as err:
|
||||
return convert_error(err)
|
||||
|
||||
|
||||
def parse_uri(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
|
||||
access = check_access(req)
|
||||
if isinstance(access, Error):
|
||||
return access
|
||||
|
||||
data = {}
|
||||
for key in req.params:
|
||||
data[key] = req.params[key]
|
||||
|
||||
try:
|
||||
return cls.parse_obj(data)
|
||||
except ValidationError as err:
|
||||
return convert_error(err)
|
||||
|
||||
|
||||
class RequestException(Exception):
|
||||
def __init__(self, error: Error):
|
||||
self.error = error
|
||||
message = "error %s" % error
|
||||
super().__init__(message)
|
47
src/api-service/__app__/onefuzzlib/task_event.py
Normal file
47
src/api-service/__app__/onefuzzlib/task_event.py
Normal file
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.models import TaskEvent as BASE_TASK_EVENT
|
||||
from onefuzztypes.models import (
|
||||
TaskEventSummary,
|
||||
WorkerDoneEvent,
|
||||
WorkerEvent,
|
||||
WorkerRunningEvent,
|
||||
)
|
||||
|
||||
from .orm import ORMMixin
|
||||
|
||||
|
||||
class TaskEvent(BASE_TASK_EVENT, ORMMixin):
|
||||
@classmethod
|
||||
def get_summary(cls, task_id: UUID) -> List[TaskEventSummary]:
|
||||
events = cls.search(query={"task_id": [task_id]})
|
||||
events.sort(key=lambda e: e.Timestamp)
|
||||
|
||||
return [
|
||||
TaskEventSummary(
|
||||
timestamp=e.Timestamp,
|
||||
event_data=cls.get_event_data(e.event_data),
|
||||
event_type=type(e.event_data.event).__name__,
|
||||
)
|
||||
for e in events
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("task_id", None)
|
||||
|
||||
@classmethod
|
||||
def get_event_data(cls, worker_event: WorkerEvent) -> str:
|
||||
event = worker_event.event
|
||||
if isinstance(event, WorkerDoneEvent):
|
||||
return "exit status: %s" % event.exit_status
|
||||
elif isinstance(event, WorkerRunningEvent):
|
||||
return ""
|
||||
else:
|
||||
return "Unrecognized event: %s" % event
|
320
src/api-service/__app__/onefuzzlib/tasks/config.py
Normal file
320
src/api-service/__app__/onefuzzlib/tasks/config.py
Normal file
@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
|
||||
from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig
|
||||
|
||||
from ..azure.containers import blob_exists, get_container_sas_url, get_containers
|
||||
from ..azure.creds import get_fuzz_storage, get_instance_name
|
||||
from ..azure.queue import get_queue_sas
|
||||
from .defs import TASK_DEFINITIONS
|
||||
|
||||
LOGGER = logging.getLogger("onefuzz")
|
||||
|
||||
|
||||
def get_input_container_queues(config: TaskConfig) -> Optional[List[str]]: # tasks.Task
|
||||
|
||||
if config.task.type not in TASK_DEFINITIONS:
|
||||
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
|
||||
|
||||
container_type = TASK_DEFINITIONS[config.task.type].monitor_queue
|
||||
if container_type:
|
||||
return [x.name for x in config.containers if x.type == container_type]
|
||||
return None
|
||||
|
||||
|
||||
def check_val(compare: Compare, expected: int, actual: int) -> bool:
|
||||
if compare == Compare.Equal:
|
||||
return expected == actual
|
||||
|
||||
if compare == Compare.AtLeast:
|
||||
return expected <= actual
|
||||
|
||||
if compare == Compare.AtMost:
|
||||
return expected >= actual
|
||||
|
||||
raise NotImplementedError
|
||||
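# NOTE: minimal illustrative sketch (not part of the original module): the
# Compare semantics read as "actual must be <comparison> expected".
def _check_val_sketch() -> None:
    assert check_val(Compare.Equal, 1, 1)
    assert check_val(Compare.AtLeast, 1, 3)  # at least 1 required, 3 provided
    assert not check_val(Compare.AtMost, 1, 3)  # at most 1 allowed, 3 provided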
|
||||
|
||||
def check_container(
|
||||
compare: Compare,
|
||||
expected: int,
|
||||
container_type: ContainerType,
|
||||
containers: Dict[ContainerType, List[str]],
|
||||
) -> None:
|
||||
actual = len(containers.get(container_type, []))
|
||||
if not check_val(compare, expected, actual):
|
||||
raise TaskConfigError(
|
||||
"container type %s: expected %s %d, got %d"
|
||||
% (container_type.name, compare.name, expected, actual)
|
||||
)
|
||||
|
||||
|
||||
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
|
||||
all_containers = set(get_containers().keys())
|
||||
|
||||
containers: Dict[ContainerType, List[str]] = {}
|
||||
for container in config.containers:
|
||||
if container.name not in all_containers:
|
||||
raise TaskConfigError("missing container: %s" % container.name)
|
||||
|
||||
if container.type not in containers:
|
||||
containers[container.type] = []
|
||||
containers[container.type].append(container.name)
|
||||
|
||||
for container_def in definition.containers:
|
||||
check_container(
|
||||
container_def.compare, container_def.value, container_def.type, containers
|
||||
)
|
||||
|
||||
for container_type in containers:
|
||||
if container_type not in [x.type for x in definition.containers]:
|
||||
raise TaskConfigError(
|
||||
"unsupported container type for this task: %s", container_type.name
|
||||
)
|
||||
|
||||
if definition.monitor_queue:
|
||||
if definition.monitor_queue not in [x.type for x in definition.containers]:
|
||||
raise TaskConfigError(
|
||||
"unable to monitor container type as it is not used by this task: %s"
|
||||
% definition.monitor_queue.name
|
||||
)
|
||||
|
||||
|
||||
def check_config(config: TaskConfig) -> None:
|
||||
if config.task.type not in TASK_DEFINITIONS:
|
||||
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
|
||||
|
||||
if config.vm is not None and config.pool is not None:
|
||||
raise TaskConfigError("either the vm or pool must be specified, but not both")
|
||||
|
||||
definition = TASK_DEFINITIONS[config.task.type]
|
||||
|
||||
check_containers(definition, config)
|
||||
|
||||
if (
|
||||
TaskFeature.supervisor_exe in definition.features
|
||||
and not config.task.supervisor_exe
|
||||
):
|
||||
err = "missing supervisor_exe"
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError("missing supervisor_exe")
|
||||
|
||||
if config.vm:
|
||||
if not check_val(definition.vm.compare, definition.vm.value, config.vm.count):
|
||||
err = "invalid vm count: expected %s %d, got %s" % (
|
||||
definition.vm.compare,
|
||||
definition.vm.value,
|
||||
config.vm.count,
|
||||
)
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
elif config.pool:
|
||||
if not check_val(definition.vm.compare, definition.vm.value, config.pool.count):
|
||||
err = "invalid vm count: expected %s %d, got %s" % (
|
||||
definition.vm.compare,
|
||||
definition.vm.value,
|
||||
config.pool.count,
|
||||
)
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
else:
|
||||
raise TaskConfigError("either the vm or pool must be specified")
|
||||
|
||||
if TaskFeature.target_exe in definition.features:
|
||||
container = [x for x in config.containers if x.type == ContainerType.setup][0]
|
||||
if not blob_exists(container.name, config.task.target_exe):
|
||||
err = "target_exe `%s` does not exist in the setup container `%s`" % (
|
||||
config.task.target_exe,
|
||||
container.name,
|
||||
)
|
||||
LOGGER.warning(err)
|
||||
|
||||
if TaskFeature.generator_exe in definition.features:
|
||||
container = [x for x in config.containers if x.type == ContainerType.tools][0]
|
||||
if not config.task.generator_exe:
|
||||
raise TaskConfigError("generator_exe is not defined")
|
||||
|
||||
tools_paths = ["{tools_dir}/", "{tools_dir}\\"]
|
||||
for tool_path in tools_paths:
|
||||
if config.task.generator_exe.startswith(tool_path):
|
||||
generator = config.task.generator_exe.replace(tool_path, "")
|
||||
if not blob_exists(container.name, generator):
|
||||
err = (
|
||||
"generator_exe `%s` does not exist in the tools container `%s`"
|
||||
% (
|
||||
config.task.generator_exe,
|
||||
container.name,
|
||||
)
|
||||
)
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
|
||||
if TaskFeature.stats_file in definition.features:
|
||||
if config.task.stats_file is not None and config.task.stats_format is None:
|
||||
err = "using a stats_file requires a stats_format"
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
|
||||
|
||||
def build_task_config(
|
||||
job_id: UUID, task_id: UUID, task_config: TaskConfig
|
||||
) -> TaskUnitConfig:
|
||||
|
||||
if task_config.task.type not in TASK_DEFINITIONS:
|
||||
raise TaskConfigError("unsupported task type: %s" % task_config.task.type.name)
|
||||
|
||||
definition = TASK_DEFINITIONS[task_config.task.type]
|
||||
|
||||
config = TaskUnitConfig(
|
||||
job_id=job_id,
|
||||
task_id=task_id,
|
||||
task_type=task_config.task.type,
|
||||
instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
||||
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
||||
heartbeat_queue=get_queue_sas(
|
||||
"heartbeat",
|
||||
account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
|
||||
add=True,
|
||||
),
|
||||
back_channel_address="https://%s.azurewebsites.net/api/back_channel"
|
||||
% (get_instance_name()),
|
||||
)
|
||||
|
||||
if definition.monitor_queue:
|
||||
config.input_queue = get_queue_sas(
|
||||
task_id,
|
||||
add=True,
|
||||
read=True,
|
||||
update=True,
|
||||
process=True,
|
||||
account_id=get_fuzz_storage(),
|
||||
)
|
||||
|
||||
for container_def in definition.containers:
|
||||
if container_def.type == ContainerType.setup:
|
||||
continue
|
||||
|
||||
containers = []
|
||||
for (i, container) in enumerate(task_config.containers):
|
||||
if container.type != container_def.type:
|
||||
continue
|
||||
|
||||
containers.append(
|
||||
{
|
||||
"path": "_".join(["task", container_def.type.name, str(i)]),
|
||||
"url": get_container_sas_url(
|
||||
container.name,
|
||||
read=ContainerPermission.Read in container_def.permissions,
|
||||
write=ContainerPermission.Write in container_def.permissions,
|
||||
add=ContainerPermission.Add in container_def.permissions,
|
||||
delete=ContainerPermission.Delete in container_def.permissions,
|
||||
create=ContainerPermission.Create in container_def.permissions,
|
||||
list=ContainerPermission.List in container_def.permissions,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if not containers:
|
||||
continue
|
||||
|
||||
if (
|
||||
container_def.compare in [Compare.Equal, Compare.AtMost]
|
||||
and container_def.value == 1
|
||||
):
|
||||
setattr(config, container_def.type.name, containers[0])
|
||||
else:
|
||||
setattr(config, container_def.type.name, containers)
|
||||
|
||||
EMPTY_DICT: Dict[str, str] = {}
|
||||
EMPTY_LIST: List[str] = []
|
||||
|
||||
if TaskFeature.supervisor_exe in definition.features:
|
||||
config.supervisor_exe = task_config.task.supervisor_exe
|
||||
|
||||
if TaskFeature.supervisor_env in definition.features:
|
||||
config.supervisor_env = task_config.task.supervisor_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.supervisor_options in definition.features:
|
||||
config.supervisor_options = task_config.task.supervisor_options or EMPTY_LIST
|
||||
|
||||
if TaskFeature.supervisor_input_marker in definition.features:
|
||||
config.supervisor_input_marker = task_config.task.supervisor_input_marker
|
||||
|
||||
if TaskFeature.target_exe in definition.features:
|
||||
config.target_exe = "setup/%s" % task_config.task.target_exe
|
||||
|
||||
if TaskFeature.target_env in definition.features:
|
||||
config.target_env = task_config.task.target_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.target_options in definition.features:
|
||||
config.target_options = task_config.task.target_options or EMPTY_LIST
|
||||
|
||||
if TaskFeature.target_options_merge in definition.features:
|
||||
config.target_options_merge = task_config.task.target_options_merge or False
|
||||
|
||||
if TaskFeature.rename_output in definition.features:
|
||||
config.rename_output = task_config.task.rename_output or False
|
||||
|
||||
if TaskFeature.generator_exe in definition.features:
|
||||
config.generator_exe = task_config.task.generator_exe
|
||||
|
||||
if TaskFeature.generator_env in definition.features:
|
||||
config.generator_env = task_config.task.generator_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.generator_options in definition.features:
|
||||
config.generator_options = task_config.task.generator_options or EMPTY_LIST
|
||||
|
||||
if (
|
||||
TaskFeature.wait_for_files in definition.features
|
||||
and task_config.task.wait_for_files
|
||||
):
|
||||
config.wait_for_files = task_config.task.wait_for_files.name
|
||||
|
||||
if TaskFeature.analyzer_exe in definition.features:
|
||||
config.analyzer_exe = task_config.task.analyzer_exe
|
||||
|
||||
if TaskFeature.analyzer_options in definition.features:
|
||||
config.analyzer_options = task_config.task.analyzer_options or EMPTY_LIST
|
||||
|
||||
if TaskFeature.analyzer_env in definition.features:
|
||||
config.analyzer_env = task_config.task.analyzer_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.stats_file in definition.features:
|
||||
config.stats_file = task_config.task.stats_file
|
||||
config.stats_format = task_config.task.stats_format
|
||||
|
||||
if TaskFeature.target_timeout in definition.features:
|
||||
config.target_timeout = task_config.task.target_timeout
|
||||
|
||||
if TaskFeature.check_asan_log in definition.features:
|
||||
config.check_asan_log = task_config.task.check_asan_log
|
||||
|
||||
if TaskFeature.check_debugger in definition.features:
|
||||
config.check_debugger = task_config.task.check_debugger
|
||||
|
||||
if TaskFeature.check_retry_count in definition.features:
|
||||
config.check_retry_count = task_config.task.check_retry_count or 0
|
||||
|
||||
return config
|
||||
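# NOTE (illustrative, not part of the original source): for a container
# definition with compare=Equal and value=1 (e.g. the crashes container of
# libfuzzer_fuzz), the attribute set above ends up as a single
# {"path": "task_crashes_<i>", "url": "<sas url>"} dict, while repeatable
# container types (compare=AtLeast) become a list of such dicts.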
|
||||
|
||||
def get_setup_container(config: TaskConfig) -> str:
|
||||
for container in config.containers:
|
||||
if container.type == ContainerType.setup:
|
||||
return container.name
|
||||
|
||||
raise TaskConfigError(
|
||||
"task missing setup container: task_type = %s" % config.task.type
|
||||
)
|
||||
|
||||
|
||||
class TaskConfigError(Exception):
|
||||
pass
|
376
src/api-service/__app__/onefuzzlib/tasks/defs.py
Normal file
376
src/api-service/__app__/onefuzzlib/tasks/defs.py
Normal file
@ -0,0 +1,376 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from onefuzztypes.enums import (
|
||||
Compare,
|
||||
ContainerPermission,
|
||||
ContainerType,
|
||||
TaskFeature,
|
||||
TaskType,
|
||||
)
|
||||
from onefuzztypes.models import ContainerDefinition, TaskDefinition, VmDefinition
|
||||
|
||||
# all tasks are required to have a 'setup' container
|
||||
TASK_DEFINITIONS = {
|
||||
TaskType.generic_analysis: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.analyzer_exe,
|
||||
TaskFeature.analyzer_env,
|
||||
TaskFeature.analyzer_options,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.analysis,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.Write,
|
||||
                    ContainerPermission.Read,
                    ContainerPermission.List,
                    ContainerPermission.Create,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
    TaskType.libfuzzer_fuzz: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_workers,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Write, ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                    ContainerPermission.Create,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=0,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.libfuzzer_crash_report: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_retry_count,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
    TaskType.libfuzzer_coverage: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
        ],
        vm=VmDefinition(compare=Compare.Equal, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.coverage,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Create,
                    ContainerPermission.List,
                    ContainerPermission.Read,
                    ContainerPermission.Write,
                ],
            ),
        ],
        monitor_queue=ContainerType.readonly_inputs,
    ),
    TaskType.libfuzzer_merge: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
        ],
        vm=VmDefinition(compare=Compare.Equal, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create, ContainerPermission.List],
            ),
        ],
        monitor_queue=ContainerType.inputs,
    ),
    TaskType.generic_supervisor: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_options,
            TaskFeature.supervisor_exe,
            TaskFeature.supervisor_env,
            TaskFeature.supervisor_options,
            TaskFeature.supervisor_input_marker,
            TaskFeature.wait_for_files,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Create,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.generic_merge: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_options,
            TaskFeature.supervisor_exe,
            TaskFeature.supervisor_env,
            TaskFeature.supervisor_options,
            TaskFeature.supervisor_input_marker,
            TaskFeature.stats_file,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.generic_generator: TaskDefinition(
        features=[
            TaskFeature.generator_exe,
            TaskFeature.generator_env,
            TaskFeature.generator_options,
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.rename_output,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.generic_crash_report: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Create],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
}
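Each entry above pairs a TaskType with the containers the agent needs, where every ContainerDefinition carries a compare/value rule and the permissions granted on that container. As a rough illustration of how such a compare/value pair can be applied to a submitted container count, the standalone sketch below defines a local Compare enum that only mirrors onefuzztypes.enums.Compare for the example; it is not the service's actual validation code.

# Minimal standalone sketch: applying a ContainerDefinition's compare/value
# rule to the number of containers a request supplies.
from enum import Enum


class Compare(Enum):
    Equal = "Equal"
    AtLeast = "AtLeast"
    AtMost = "AtMost"


def container_count_ok(compare: Compare, expected: int, actual: int) -> bool:
    # Equal: exactly `expected`; AtLeast/AtMost: lower/upper bound
    if compare == Compare.Equal:
        return actual == expected
    if compare == Compare.AtLeast:
        return actual >= expected
    return actual <= expected


# e.g. libfuzzer_fuzz requires exactly one `crashes` container
assert container_count_ok(Compare.Equal, 1, 1)
assert not container_count_ok(Compare.AtLeast, 1, 0)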
231
src/api-service/__app__/onefuzzlib/tasks/main.py
Normal file
231
src/api-service/__app__/onefuzzlib/tasks/main.py
Normal file
@ -0,0 +1,231 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from uuid import UUID

from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm

from ..azure.creds import get_fuzz_storage
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward


class Task(BASE_TASK, ORMMixin):
    def ready_to_schedule(self) -> bool:
        if self.config.prereq_tasks:
            for task_id in self.config.prereq_tasks:
                task = Task.get_by_task_id(task_id)
                # if a prereq task fails, then mark this task as failed
                if isinstance(task, Error):
                    self.error = task
                    self.state = TaskState.stopping
                    self.save()
                    return False

                if task.state not in task.state.has_started():
                    return False
        return True

    @classmethod
    def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
        if config.vm:
            os = get_os(config.vm.region, config.vm.image)
        elif config.pool:
            pool = Pool.get_by_name(config.pool.pool_name)
            if isinstance(pool, Error):
                return pool
            os = pool.os
        else:
            raise Exception("task must have vm or pool")
        task = cls(config=config, job_id=job_id, os=os)
        task.save()
        return task

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        return {"heartbeats": ...}

    def is_ready(self) -> bool:
        if self.config.prereq_tasks:
            for prereq_id in self.config.prereq_tasks:
                prereq = Task.get_by_task_id(prereq_id)
                if isinstance(prereq, Error):
                    logging.info("task prereq has error: %s - %s", self.task_id, prereq)
                    self.error = prereq
                    self.state = TaskState.stopping
                    self.save()
                    return False
                if prereq.state != TaskState.running:
                    logging.info(
                        "task is waiting on prereq: %s - %s:",
                        self.task_id,
                        prereq.task_id,
                    )
                    return False

        return True

    # At current, the telemetry filter will generate something like this:
    #
    # {
    #     'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
    #     'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
    #     'state': 'stopped',
    #     'config': {
    #         'task': {'type': 'libfuzzer_coverage'},
    #         'vm': {'count': 1}
    #     }
    # }
    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "config": {"vm": {"count": ...}, "task": {"type": ...}},
        }

    def event_include(self) -> Optional[MappingIntStrAny]:
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "error": ...,
        }

    def init(self) -> None:
        create_queue(self.task_id, account_id=get_fuzz_storage())
        self.state = TaskState.waiting
        self.save()

    def stopping(self) -> None:
        # TODO: we need to tell every node currently working on this task to stop
        # TODO: we need to 'unschedule' this task from the existing pools

        self.state = TaskState.stopping
        logging.info("stopping task: %s:%s", self.job_id, self.task_id)
        ProxyForward.remove_forward(self.task_id)
        delete_queue(str(self.task_id), account_id=get_fuzz_storage())
        Node.stop_task(self.task_id)
        self.state = TaskState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    @classmethod
    def search_states(
        cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
    ) -> List["Task"]:
        query: QueryFilter = {}
        if job_id:
            query["job_id"] = [job_id]
        if states:
            query["state"] = states

        return cls.search(query=query)

    @classmethod
    def search_expired(cls) -> List["Task"]:
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(
            query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
        )

    @classmethod
    def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
        tasks = cls.search(query={"task_id": [task_id]})
        if not tasks:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])

        if len(tasks) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
            )
        task = tasks[0]
        return task

    def get_pool(self) -> Optional[Pool]:
        if self.config.pool:
            pool = Pool.get_by_name(self.config.pool.pool_name)
            if isinstance(pool, Error):
                logging.info(
                    "unable to schedule task to pool: %s - %s", self.task_id, pool
                )
                return None
            return pool
        elif self.config.vm:
            scalesets = Scaleset.search()
            scalesets = [
                x
                for x in scalesets
                if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
            ]
            for scaleset in scalesets:
                pool = Pool.get_by_name(scaleset.pool_name)
                if isinstance(pool, Error):
                    logging.info(
                        "unable to schedule task to pool: %s - %s",
                        self.task_id,
                        pool,
                    )
                else:
                    return pool

        logging.warning(
            "unable to find a scaleset that matches the task prereqs: %s",
            self.task_id,
        )
        return None

    def get_repro_vm_config(self) -> Union[TaskVm, None]:
        if self.config.vm:
            return self.config.vm

        if self.config.pool is None:
            raise Exception("either pool or vm must be specified: %s" % self.task_id)

        pool = Pool.get_by_name(self.config.pool.pool_name)
        if isinstance(pool, Error):
            logging.info("unable to find pool from task: %s", self.task_id)
            return None

        for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
            return TaskVm(
                region=scaleset.region,
                sku=scaleset.vm_sku,
                image=scaleset.image,
            )

        logging.warning(
            "no scalesets are defined for task: %s:%s", self.job_id, self.task_id
        )
        return None

    def on_start(self) -> None:
        # try to keep this effectively idempotent

        if self.end_time is None:
            self.end_time = datetime.utcnow() + timedelta(
                hours=self.config.task.duration
            )
            self.save()

        from ..jobs import Job

        job = Job.get(self.job_id)
        if job:
            job.on_start()

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("job_id", "task_id")
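Task lookups such as get_by_task_id return Union[Error, "Task"] rather than raising, so callers branch on isinstance. A hedged sketch of a caller following that convention is below; it assumes it lives inside the onefuzzlib package, and stop_task_if_present is a hypothetical helper, not part of this commit.

# Hypothetical caller of the Union[Error, "Task"] convention used above.
from uuid import UUID

from onefuzztypes.models import Error

from .tasks.main import Task


def stop_task_if_present(task_id: UUID) -> bool:
    # get_by_task_id returns an Error rather than raising when lookup fails
    task = Task.get_by_task_id(task_id)
    if isinstance(task, Error):
        return False
    # queue_stop() defers the actual teardown to the update queue
    task.queue_stop()
    return True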
94
src/api-service/__app__/onefuzzlib/tasks/scheduler.py
Normal file
94
src/api-service/__app__/onefuzzlib/tasks/scheduler.py
Normal file
@ -0,0 +1,94 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Dict, List
from uuid import UUID

from onefuzztypes.enums import OS, TaskState
from onefuzztypes.models import WorkSet, WorkUnit

from ..azure.containers import blob_exists, get_container_sas_url, save_blob
from ..azure.creds import get_func_storage
from .config import build_task_config, get_setup_container
from .main import Task

HOURS = 60 * 60


def schedule_tasks() -> None:
    to_schedule: Dict[UUID, List[Task]] = {}

    for task in Task.search_states(states=[TaskState.waiting]):
        if not task.ready_to_schedule():
            continue

        if task.job_id not in to_schedule:
            to_schedule[task.job_id] = []
        to_schedule[task.job_id].append(task)

    for tasks in to_schedule.values():
        # TODO: for now, we're only scheduling one task per VM.

        for task in tasks:
            logging.info("scheduling task: %s", task.task_id)
            agent_config = build_task_config(task.job_id, task.task_id, task.config)

            setup_container = get_setup_container(task.config)
            setup_url = get_container_sas_url(setup_container, read=True, list=True)

            setup_script = None

            if task.os == OS.windows and blob_exists(setup_container, "setup.ps1"):
                setup_script = "setup.ps1"
            if task.os == OS.linux and blob_exists(setup_container, "setup.sh"):
                setup_script = "setup.sh"

            save_blob(
                "task-configs",
                "%s/config.json" % task.task_id,
                agent_config.json(),
                account_id=get_func_storage(),
            )
            reboot = False
            count = 1
            if task.config.pool:
                count = task.config.pool.count
                reboot = task.config.task.reboot_after_setup is True
            elif task.config.vm:
                # this branch should go away when we stop letting people specify
                # VM configs directly.
                count = task.config.vm.count
                reboot = (
                    task.config.vm.reboot_after_setup is True
                    or task.config.task.reboot_after_setup is True
                )

            task_config = agent_config
            task_config_json = task_config.json()
            work_unit = WorkUnit(
                job_id=task_config.job_id,
                task_id=task_config.task_id,
                task_type=task_config.task_type,
                config=task_config_json,
            )

            # For now, only offer singleton work sets.
            work_set = WorkSet(
                reboot=reboot,
                script=(setup_script is not None),
                setup_url=setup_url,
                work_units=[work_unit],
            )

            pool = task.get_pool()
            if not pool:
                logging.info("unable to find pool for task: %s", task.task_id)
                continue

            for _ in range(count):
                pool.schedule_workset(work_set)
            task.state = TaskState.scheduled
            task.save()
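For reference, the payload that schedule_tasks hands to pool.schedule_workset has the shape sketched below. The UUIDs, URL, and task type are placeholder values; only the field names are taken from the usage above.

# Placeholder-valued sketch of the WorkUnit/WorkSet built by schedule_tasks().
from uuid import uuid4

from onefuzztypes.enums import TaskType
from onefuzztypes.models import WorkSet, WorkUnit

work_unit = WorkUnit(
    job_id=uuid4(),
    task_id=uuid4(),
    task_type=TaskType.libfuzzer_fuzz,
    config="{}",  # JSON-serialized agent config in the real flow
)

work_set = WorkSet(
    reboot=False,
    script=False,  # True when a setup.sh / setup.ps1 blob exists
    setup_url="https://example.invalid/setup-container?sas=...",
    work_units=[work_unit],  # singleton work sets, as noted above
)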
72
src/api-service/__app__/onefuzzlib/telemetry.py
Normal file
72
src/api-service/__app__/onefuzzlib/telemetry.py
Normal file
@ -0,0 +1,72 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, Dict, Optional, Union

from onefuzztypes.enums import TelemetryData, TelemetryEvent
from opencensus.ext.azure.log_exporter import AzureLogHandler

LOCAL_CLIENT: Optional[logging.Logger] = None
CENTRAL_CLIENT: Optional[logging.Logger] = None


def _get_client(environ_key: str) -> Optional[logging.Logger]:
    key = os.environ.get(environ_key)
    if key is None:
        return None
    client = logging.getLogger("onefuzz")
    client.addHandler(AzureLogHandler(connection_string="InstrumentationKey=%s" % key))
    return client


def _central_client() -> Optional[logging.Logger]:
    global CENTRAL_CLIENT
    if not CENTRAL_CLIENT:
        CENTRAL_CLIENT = _get_client("ONEFUZZ_TELEMETRY")
    return CENTRAL_CLIENT


def _local_client() -> Union[None, Any, logging.Logger]:
    global LOCAL_CLIENT
    if not LOCAL_CLIENT:
        LOCAL_CLIENT = _get_client("APPINSIGHTS_INSTRUMENTATIONKEY")
    return LOCAL_CLIENT


# NOTE: All telemetry that is *NOT* using the ORM telemetry_include should
# go through this method
#
# This provides a point of inspection to know if it's data that is safe to
# log to the central OneFuzz telemetry point
def track_event(
    event: TelemetryEvent, data: Dict[TelemetryData, Union[str, int]]
) -> None:
    central = _central_client()
    local = _local_client()

    if local:
        serialized = {k.name: v for (k, v) in data.items()}
        local.info(event.name, extra={"custom_dimensions": serialized})

    if event in TelemetryEvent.can_share() and central:
        serialized = {
            k.name: v for (k, v) in data.items() if k in TelemetryData.can_share()
        }
        central.info(event.name, extra={"custom_dimensions": serialized})


# NOTE: This should *only* be used for logging Telemetry data that uses
# the ORM telemetry_include method to limit data for telemetry.
def track_event_filtered(event: TelemetryEvent, data: Any) -> None:
    central = _central_client()
    local = _local_client()

    if local:
        local.info(event.name, extra={"custom_dimensions": data})

    if central and event in TelemetryEvent.can_share():
        central.info(event.name, extra={"custom_dimensions": data})
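A hedged sketch of a track_event call site follows. The enum member names used (state_changed, task_id) are illustrative only and may not exist in onefuzztypes; the point is that the local Application Insights instance receives every field, while the central OneFuzz instance only receives events and fields allowed by can_share().

# Illustrative call site; TelemetryEvent/TelemetryData member names are assumed.
from onefuzztypes.enums import TelemetryData, TelemetryEvent

from .telemetry import track_event

track_event(
    TelemetryEvent.state_changed,  # assumed member name
    {TelemetryData.task_id: "00000000-0000-0000-0000-000000000000"},  # assumed member
)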
110
src/api-service/__app__/onefuzzlib/updates.py
Normal file
110
src/api-service/__app__/onefuzzlib/updates.py
Normal file
@ -0,0 +1,110 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Dict, Optional, Type

from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType
from pydantic import BaseModel

from .azure.creds import get_func_storage
from .azure.queue import queue_object


# This class isn't intended to be shared outside of the service
class Update(BaseModel):
    update_type: UpdateType
    PartitionKey: Optional[str]
    RowKey: Optional[str]
    method: Optional[str]


def queue_update(
    update_type: UpdateType,
    PartitionKey: Optional[str] = None,
    RowKey: Optional[str] = None,
    method: Optional[str] = None,
    visibility_timeout: Optional[int] = None,
) -> None:
    logging.info(
        "queuing type:%s id:[%s,%s] method:%s timeout: %s",
        update_type.name,
        PartitionKey,
        RowKey,
        method,
        visibility_timeout,
    )

    update = Update(
        update_type=update_type, PartitionKey=PartitionKey, RowKey=RowKey, method=method
    )

    try:
        if not queue_object(
            "update-queue",
            update,
            account_id=get_func_storage(),
            visibility_timeout=visibility_timeout,
        ):
            logging.error("unable to queue update")
    except CloudError as err:
        logging.error("GOT ERROR %s", repr(err))
        logging.error("GOT ERROR %s", dir(err))
        raise err


def execute_update(update: Update) -> None:
    from .jobs import Job
    from .orm import ORMMixin
    from .pools import Node, Pool, Scaleset
    from .proxy import Proxy
    from .repro import Repro
    from .tasks.main import Task

    update_objects: Dict[UpdateType, Type[ORMMixin]] = {
        UpdateType.Task: Task,
        UpdateType.Job: Job,
        UpdateType.Repro: Repro,
        UpdateType.Proxy: Proxy,
        UpdateType.Pool: Pool,
        UpdateType.Node: Node,
        UpdateType.Scaleset: Scaleset,
    }

    # TODO: remove these from being queued, these updates are handled elsewhere
    if update.update_type == UpdateType.Scaleset:
        return

    if update.update_type in update_objects:
        if update.PartitionKey is None or update.RowKey is None:
            raise Exception("unsupported update: %s" % update)

        obj = update_objects[update.update_type].get(update.PartitionKey, update.RowKey)
        if not obj:
            logging.error("unable to find obj to update %s", update)
            return

        if update.method and hasattr(obj, update.method):
            getattr(obj, update.method)()
            return
        else:
            state = getattr(obj, "state", None)
            if state is None:
                logging.error("queued update for object without state: %s", update)
                return
            func = getattr(obj, state.name, None)
            if func is None:
                logging.info("no function to implement state: %s", update)
                return
            func()
            return

    raise NotImplementedError("unimplemented update type: %s" % update.update_type.name)


def perform_update(update: Update) -> None:
    logging.info("performing queued update: %s", update)
    execute_update(update)
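The update queue written by queue_update is drained elsewhere in the service; a sketch of what such a consumer could look like as an Azure Functions queue trigger is below. The function layout, queue binding, and import path are assumptions for illustration, not the actual trigger shipped in this commit.

# Hypothetical queue-triggered function that replays a queued Update.
import json

import azure.functions as func

from ..onefuzzlib.updates import Update, perform_update


def main(msg: func.QueueMessage) -> None:
    # messages are JSON-serialized Update models (see queue_object usage above)
    update = Update.parse_obj(json.loads(msg.get_body()))
    perform_update(update)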
31
src/api-service/__app__/onefuzzlib/versions.py
Normal file
31
src/api-service/__app__/onefuzzlib/versions.py
Normal file
@ -0,0 +1,31 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
from typing import Dict

from memoization import cached
from onefuzztypes.responses import Version

from .__version__ import __version__


@cached
def read_local_file(filename: str) -> str:
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), filename)
    if os.path.exists(path):
        with open(path, "r") as handle:
            return handle.read().strip()
    else:
        return "UNKNOWN"


def versions() -> Dict[str, Version]:
    entry = Version(
        git=read_local_file("git.version"),
        build=read_local_file("build.id"),
        version=__version__,
    )
    return {"onefuzz": entry}
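When neither git.version nor build.id has been deployed next to the module, read_local_file falls back to "UNKNOWN", so versions() evaluates to roughly the sketch below (assuming it runs inside the onefuzzlib package).

# Approximate return value of versions() without git.version / build.id files.
from onefuzztypes.responses import Version

from .__version__ import __version__

expected = {"onefuzz": Version(git="UNKNOWN", build="UNKNOWN", version=__version__)}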