initial public release

This commit is contained in:
Brian Caswell
2020-09-18 12:21:04 -04:00
parent 9c3aa0bdfb
commit d3a0b292e6
387 changed files with 43810 additions and 28 deletions

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import subprocess # nosec - used for ssh key generation
import tempfile
from typing import Tuple
from uuid import uuid4
from onefuzztypes.models import Authentication
def generate_keypair() -> Tuple[str, str]:
    """Create a fresh 2048-bit RSA SSH keypair via ssh-keygen.

    Returns a (public_key, private_key) pair of key-file text; the
    public key is stripped of trailing whitespace.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        key_path = os.path.join(tmpdir, "key")
        # -P "" -> no passphrase
        subprocess.check_output(  # nosec - all arguments are under our control
            ["ssh-keygen", "-t", "rsa", "-f", key_path, "-P", "", "-b", "2048"]
        )
        with open(key_path, "r") as handle:
            private = handle.read()
        with open(key_path + ".pub", "r") as handle:
            public = handle.read().strip()
    return (public, private)
def build_auth() -> Authentication:
    """Bundle a random password with a newly generated SSH keypair."""
    public_key, private_key = generate_keypair()
    return Authentication(
        password=str(uuid4()),
        public_key=public_key,
        private_key=private_key,
    )

View File

@ -0,0 +1,155 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import os
import urllib.parse
from typing import Any, Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions
from .creds import get_blob_service
def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]:
    """Map container name -> metadata for every non-system container."""
    containers = get_blob_service(account_id).list_containers(include_metadata=True)
    # names starting with "$" are Azure system containers
    return {c.name: c.metadata for c in containers if not c.name.startswith("$")}
def get_container_metadata(
    name: str, account_id: Optional[str] = None
) -> Optional[Dict[str, str]]:
    """Fetch a container's metadata; None if the lookup fails or is empty."""
    try:
        metadata = get_blob_service(account_id).get_container_metadata(name)
    except AzureHttpError:
        return None
    if not metadata:
        return None
    return cast(Dict[str, str], metadata)
def create_container(
    name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None
) -> Optional[str]:
    """Create a container and return a full-permission SAS URL for it.

    Returns None when creation fails (azure storage already logs errors).
    """
    try:
        get_blob_service(account_id).create_container(name, metadata=metadata)
    except AzureHttpError:
        # azure storage already logs errors
        return None
    # BUGFIX: propagate account_id so the SAS URL is generated against the
    # same storage account the container was just created in, rather than
    # the default account.
    return get_container_sas_url(
        name,
        account_id=account_id,
        read=True,
        add=True,
        create=True,
        write=True,
        delete=True,
        list=True,
    )
def delete_container(name: str, account_id: Optional[str] = None) -> bool:
    """Delete a container; True on success, False on any storage error."""
    try:
        result = get_blob_service(account_id).delete_container(name)
    except AzureHttpError:
        # azure storage already logs errors
        return False
    return bool(result)
def get_container_sas_url(
    container: str,
    account_id: Optional[str] = None,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
) -> str:
    """Build a container URL carrying a 30-day SAS token with the given rights."""
    service = get_blob_service(account_id)
    sas_token = service.generate_container_shared_access_signature(
        container,
        permission=ContainerPermissions(read, add, create, write, delete, list),
        expiry=datetime.datetime.utcnow() + datetime.timedelta(days=30),
    )
    raw_url = service.make_container_url(container, sas_token=sas_token)
    # make_container_url appends restype=container, which SAS consumers reject
    return str(raw_url.replace("?restype=container&", "?"))
def get_file_sas_url(
    container: str,
    name: str,
    account_id: Optional[str] = None,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
    days: int = 30,
    hours: int = 0,
    minutes: int = 0,
) -> str:
    """Build a blob URL carrying a SAS token with the given rights and lifetime."""
    service = get_blob_service(account_id)
    lifetime = datetime.timedelta(days=days, hours=hours, minutes=minutes)
    sas_token = service.generate_blob_shared_access_signature(
        container,
        name,
        permission=BlobPermissions(read, add, create, write, delete, list),
        expiry=datetime.datetime.utcnow() + lifetime,
    )
    return str(service.make_blob_url(container, name, sas_token=sas_token))
def save_blob(
    container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
) -> None:
    """Write text or binary data to a blob, creating the container if needed."""
    service = get_blob_service(account_id)
    service.create_container(container)
    # str and bytes are disjoint; anything else is silently ignored
    if isinstance(data, bytes):
        service.create_blob_from_bytes(container, name, data)
    elif isinstance(data, str):
        service.create_blob_from_text(container, name, data)
def get_blob(
    container: str, name: str, account_id: Optional[str] = None
) -> Optional[Any]:  # should be bytes
    """Read a blob's content; None if the blob does not exist."""
    try:
        return get_blob_service(account_id).get_blob_to_bytes(container, name).content
    except AzureMissingResourceHttpError:
        return None
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
    """True if the blob's properties can be fetched, False if it is missing."""
    try:
        get_blob_service(account_id).get_blob_properties(container, name)
    except AzureMissingResourceHttpError:
        return False
    return True
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
    """Delete a blob; True on success, False if it was already missing."""
    try:
        get_blob_service(account_id).delete_blob(container, name)
    except AzureMissingResourceHttpError:
        return False
    return True
def auth_download_url(container: str, filename: str) -> str:
    """Build an authenticated download URL on this onefuzz instance."""
    query = urllib.parse.urlencode({"container": container, "filename": filename})
    return "%s/api/download?%s" % (os.environ["ONEFUZZ_INSTANCE"], query)

View File

@ -0,0 +1,116 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, List, Optional, Tuple
from azure.cli.core import CLIError
from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.subscription import SubscriptionClient
from azure.storage.blob import BlockBlobService
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id
from .monkeypatch import allow_more_workers
@cached(ttl=60)
def get_msi() -> MSIAuthentication:
    """Managed-identity credentials, refreshed at most once a minute."""
    return MSIAuthentication()
@cached(ttl=60)
def mgmt_client_factory(client_class: Any) -> Any:
    """Build an Azure management client, preferring the CLI profile.

    Falls back to MSI credentials when no CLI profile is available;
    SubscriptionClient is the one client that takes no subscription id.
    """
    allow_more_workers()
    try:
        return get_client_from_cli_profile(client_class)
    except CLIError:
        pass
    if issubclass(client_class, SubscriptionClient):
        return client_class(get_msi())
    return client_class(get_msi(), get_subscription())
@cached(ttl=60)
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]:
    """Resolve a storage account resource id to its (name, primary key)."""
    if account_id is None:
        account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
    resource = parse_resource_id(account_id)
    client = mgmt_client_factory(StorageManagementClient)
    keys = client.storage_accounts.list_keys(
        resource["resource_group"], resource["name"]
    )
    return resource["name"], keys.keys[0].value
@cached(ttl=60)
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
    """Block-blob client for the given (or default) storage account."""
    logging.info("getting blob container (account_id: %s)", account_id)
    account_name, account_key = get_storage_account_name_key(account_id)
    return BlockBlobService(account_name=account_name, account_key=account_key)
@cached
def get_base_resource_group() -> Any:  # should be str
    """Resource group parsed from the ONEFUZZ_RESOURCE_GROUP resource id."""
    resource = parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])
    return resource["resource_group"]
@cached
def get_base_region() -> Any:  # should be str
    """Azure location of the instance's base resource group."""
    resource_client = mgmt_client_factory(ResourceManagementClient)
    return resource_client.resource_groups.get(get_base_resource_group()).location
@cached
def get_subscription() -> Any:  # should be str
    """Subscription id parsed from the data-storage resource id."""
    storage = os.environ["ONEFUZZ_DATA_STORAGE"]
    return parse_resource_id(storage)["subscription"]


@cached
def get_fuzz_storage() -> str:
    """Resource id of the storage account holding fuzzing data."""
    return os.environ["ONEFUZZ_DATA_STORAGE"]


@cached
def get_func_storage() -> str:
    """Resource id of the storage account backing the functions app."""
    return os.environ["ONEFUZZ_FUNC_STORAGE"]


@cached
def get_instance_name() -> str:
    """Name of this onefuzz instance."""
    return os.environ["ONEFUZZ_INSTANCE_NAME"]
@cached(ttl=60)
def get_regions() -> List[str]:
    """Sorted list of location names available to the subscription."""
    client = mgmt_client_factory(SubscriptionClient)
    locations = client.subscriptions.list_locations(get_subscription())
    return sorted(loc.name for loc in locations)
def get_graph_client() -> Any:
    """Graph RBAC management client (CLI profile or MSI)."""
    return mgmt_client_factory(GraphRbacManagementClient)


def is_member_of(group_id: str, member_id: str) -> bool:
    """True if the AAD object is a member of the given AAD group."""
    params = CheckGroupMembershipParameters(group_id=group_id, member_id=member_id)
    result = get_graph_client().groups.is_member_of(params)
    return bool(result.value)

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Any
from azure.mgmt.compute import ComputeManagementClient
from msrestazure.azure_exceptions import CloudError
from .creds import mgmt_client_factory
def list_disks(resource_group: str) -> Any:
    """Iterate the managed disks in a resource group."""
    logging.info("listing disks %s", resource_group)
    client = mgmt_client_factory(ComputeManagementClient)
    return client.disks.list_by_resource_group(resource_group)
def delete_disk(resource_group: str, name: str) -> bool:
    """Delete a managed disk; True on success, False on a cloud error."""
    logging.info("deleting disks %s : %s", resource_group, name)
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        client.disks.delete(resource_group, name)
    except CloudError as err:
        logging.error("unable to delete disk: %s", err)
        return False
    return True

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from azure.mgmt.compute import ComputeManagementClient
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import mgmt_client_factory
@cached(ttl=60)
def get_os(region: Region, image: str) -> Union[Error, OS]:
    """Determine the OS of a VM image.

    Accepts either a resource id of a custom image (contains a resource
    group) or a marketplace reference of the form
    "publisher:offer:sku:version" (version may be "latest").

    :returns: the OS enum member, or an INVALID_IMAGE Error
    """
    client = mgmt_client_factory(ComputeManagementClient)
    parsed = parse_resource_id(image)
    if "resource_group" in parsed:
        # custom image: ask the image resource for its OS disk type
        try:
            name = client.images.get(
                parsed["resource_group"], parsed["name"]
            ).storage_profile.os_disk.os_type.name
        except CloudError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    else:
        # BUGFIX: a malformed marketplace reference previously raised an
        # unhandled ValueError instead of returning INVALID_IMAGE
        try:
            publisher, offer, sku, version = image.split(":")
        except ValueError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
        try:
            if version == "latest":
                version = client.virtual_machine_images.list(
                    region, publisher, offer, sku, top=1
                )[0].name
            name = client.virtual_machine_images.get(
                region, publisher, offer, sku, version
            ).os_disk_image.operating_system.name
        except CloudError as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    return OS[name]

View File

@ -0,0 +1,150 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, Optional, Union
from uuid import UUID
from azure.mgmt.network import NetworkManagementClient
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from .creds import get_base_resource_group, mgmt_client_factory
from .subnet import create_virtual_network, get_subnet_id
from .vmss import get_instance_id
def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
    """Find the private IP of a scaleset node, or None if unavailable."""
    instance = get_instance_id(scaleset, machine_id)
    if not isinstance(instance, str):
        # get_instance_id returns an Error object on failure
        return None
    client = mgmt_client_factory(NetworkManagementClient)
    interfaces = (
        client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
            get_base_resource_group(), str(scaleset)
        )
    )
    try:
        for nic in interfaces:
            owner = parse_resource_id(nic.virtual_machine.id)
            if owner.get("resource_name") != instance:
                continue
            for ip_config in nic.ip_configurations:
                if ip_config.private_ip_address is not None:
                    return str(ip_config.private_ip_address)
    except CloudError:
        # this can fail if an interface is removed during the iteration
        pass
    return None
def get_ip(resource_group: str, name: str) -> Optional[Any]:
    """Fetch a public IP address resource; None if it does not exist."""
    logging.info("getting ip %s:%s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        return client.public_ip_addresses.get(resource_group, name)
    except CloudError:
        return None
def delete_ip(resource_group: str, name: str) -> Any:
    """Start deletion of a public IP resource."""
    logging.info("deleting ip %s:%s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    return client.public_ip_addresses.delete(resource_group, name)
def create_ip(resource_group: str, name: str, location: str) -> Any:
    """Start creation of a dynamically-allocated public IP."""
    logging.info("creating ip for %s:%s in %s", resource_group, name, location)
    params: Dict[str, Union[str, Dict[str, str]]] = {
        "location": location,
        "public_ip_allocation_method": "Dynamic",
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    client = mgmt_client_factory(NetworkManagementClient)
    return client.public_ip_addresses.create_or_update(resource_group, name, params)
def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
    """Fetch a network interface; None if it does not exist."""
    logging.info("getting nic: %s %s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        return client.network_interfaces.get(resource_group, name)
    except CloudError:
        return None
def delete_nic(resource_group: str, name: str) -> Optional[Any]:
    """Start deletion of a network interface."""
    logging.info("deleting nic %s:%s", resource_group, name)
    client = mgmt_client_factory(NetworkManagementClient)
    return client.network_interfaces.delete(resource_group, name)
def create_public_nic(resource_group: str, name: str, location: str) -> Optional[Error]:
    """Incrementally create a NIC with a subnet and a public IP.

    Missing prerequisites are created on earlier calls, each returning
    early — callers are expected to retry until the NIC exists.

    :returns: an Error on NIC-creation failure, otherwise None (including
        the "prerequisite still being created" cases)
    """
    logging.info("creating nic for %s:%s in %s", resource_group, name, location)
    network_client = mgmt_client_factory(NetworkManagementClient)
    subnet_id = get_subnet_id(resource_group, location)
    if not subnet_id:
        # no subnet yet: kick off vnet creation and retry later
        return create_virtual_network(resource_group, location, location)
    ip = get_ip(resource_group, name)
    if not ip:
        # no public IP yet: kick off its creation and retry later
        create_ip(resource_group, name, location)
        return None
    params = {
        "location": location,
        "ip_configurations": [
            {
                "name": "myIPConfig",
                "public_ip_address": ip,
                "subnet": {"id": subnet_id},
            }
        ],
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    try:
        network_client.network_interfaces.create_or_update(resource_group, name, params)
    except CloudError as err:
        # transient "RetryableError"s are ignored; anything else is fatal
        if "RetryableError" not in repr(err):
            return Error(
                code=ErrorCode.VM_CREATE_FAILED,
                errors=["unable to create nic: %s" % err],
            )
    return None
def get_public_ip(resource_id: str) -> Optional[str]:
    """Resolve a NIC resource id to its first public IP address string."""
    logging.info("getting ip for %s", resource_id)
    network_client = mgmt_client_factory(NetworkManagementClient)
    nic_res = parse_resource_id(resource_id)
    nic = network_client.network_interfaces.get(
        nic_res["resource_group"], nic_res["name"]
    )
    # first ip configuration's public IP reference, then dereference it
    ip_res = parse_resource_id(nic.ip_configurations[0].public_ip_address.id)
    address = network_client.public_ip_addresses.get(
        ip_res["resource_group"], ip_res["name"]
    ).ip_address
    if address is None:
        return None
    return str(address)

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import Any, Dict
from azure.mgmt.loganalytics import LogAnalyticsManagementClient
from memoization import cached
from .creds import get_base_resource_group, mgmt_client_factory
# NOTE(review): "montior" is a typo for "monitor", kept as-is since callers
# reference this name.
@cached(ttl=60)
def get_montior_client() -> Any:
    """Log Analytics management client, cached for 60 seconds."""
    return mgmt_client_factory(LogAnalyticsManagementClient)
@cached(ttl=60)
def get_monitor_settings() -> Dict[str, str]:
    """Customer id and primary shared key of the log analytics workspace."""
    group = get_base_resource_group()
    workspace = os.environ["ONEFUZZ_MONITOR"]
    client = get_montior_client()
    return {
        "id": client.workspaces.get(group, workspace).customer_id,
        "key": client.shared_keys.get_shared_keys(group, workspace).primary_shared_key,
    }

View File

@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import logging
# module-level guard so the stack walk below runs at most once per process
WORKERS_DONE = False


def allow_more_workers() -> None:
    """Raise the azure-functions worker thread pool size from 1 to 50.

    HACK: walks the interpreter call stack to find the functions host's
    dispatcher frame and patches its thread pool (_sync_call_tp) in place.
    Relies on azure-functions-worker internals — may break on host upgrades.
    """
    global WORKERS_DONE
    if WORKERS_DONE:
        return
    stack = inspect.stack()
    for entry in stack:
        # the dispatcher frame's `self` owns the sync-call thread pool
        if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
            if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
                logging.info("bumped thread worker count to 50")
                entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50
    WORKERS_DONE = True

View File

@ -0,0 +1,46 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Optional, Union
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import get_base_resource_group
from .subnet import create_virtual_network, delete_subnet, get_subnet_id
class Network:
    """A per-region virtual network within the base resource group."""

    def __init__(self, region: Region):
        self.group = get_base_resource_group()
        self.region = region

    def exists(self) -> bool:
        """True if the region's subnet already exists."""
        return self.get_id() is not None

    def get_id(self) -> Optional[str]:
        """Subnet resource id for this region, or None if missing."""
        return get_subnet_id(self.group, self.region)

    def create(self) -> Union[None, Error]:
        """Create the virtual network if needed; returns an Error on failure.

        BUGFIX: create_virtual_network returns Optional[Error], never a
        CloudError, so the old `isinstance(result, CloudError)` check could
        never match and failures were silently dropped (and the dead branch
        read `result.message`, which Error does not have).
        """
        if not self.exists():
            result = create_virtual_network(self.group, self.region, self.region)
            if isinstance(result, Error):
                logging.error(
                    "network creation failed: %s- %s",
                    self.region,
                    result,
                )
                return result
        return None

    def delete(self) -> None:
        """Delete the region's subnet/vnet."""
        delete_subnet(self.group, self.region)

View File

@ -0,0 +1,153 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import base64
import datetime
import json
import logging
from typing import List, Optional, Type, TypeVar, Union
from uuid import UUID
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.queue import (
QueueSasPermissions,
QueueServiceClient,
generate_queue_sas,
)
from memoization import cached
from pydantic import BaseModel
from .creds import get_storage_account_name_key
QueueNameType = Union[str, UUID]
@cached(ttl=60)
def get_queue_client(account_id: str) -> QueueServiceClient:
    """Queue service client for the given storage account (cached 60s)."""
    # BUGFIX: log message previously said "getting blob container"
    # (copy-pasted from the blob helper)
    logging.info("getting queue client (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    account_url = "https://%s.queue.core.windows.net" % name
    return QueueServiceClient(
        account_url=account_url,
        credential={"account_name": name, "account_key": key},
    )
def get_queue_sas(
    queue: QueueNameType,
    *,
    account_id: str,
    read: bool = False,
    add: bool = False,
    update: bool = False,
    process: bool = False,
) -> str:
    """Build a queue URL carrying a 30-day SAS token with the given rights."""
    logging.info("getting queue sas %s (account_id: %s)", queue, account_id)
    account_name, account_key = get_storage_account_name_key(account_id)
    token = generate_queue_sas(
        account_name,
        str(queue),
        account_key,
        permission=QueueSasPermissions(
            read=read, add=add, update=update, process=process
        ),
        expiry=datetime.datetime.utcnow() + datetime.timedelta(days=30),
    )
    return "https://%s.queue.core.windows.net/%s?%s" % (account_name, queue, token)
@cached(ttl=60)
def create_queue(name: QueueNameType, *, account_id: str) -> None:
    """Create a queue if missing; memoized so repeat calls are skipped for 60s."""
    try:
        get_queue_client(account_id).create_queue(str(name))
    except ResourceExistsError:
        pass
def delete_queue(name: QueueNameType, *, account_id: str) -> None:
    """Delete the named queue if it exists."""
    client = get_queue_client(account_id)
    existing = [entry["name"] for entry in client.list_queues()]
    # BUGFIX: stringify the name on deletion too; `name` may be a UUID,
    # while the existence check compares against str(name)
    if str(name) in existing:
        client.delete_queue(str(name))
def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]:
    """Get a client for the named queue; None if it does not exist."""
    try:
        return get_queue_client(account_id).get_queue_client(str(name))
    except ResourceNotFoundError:
        return None
def send_message(
    name: QueueNameType,
    message: bytes,
    *,
    account_id: str,
) -> None:
    """Best-effort: base64-encode and enqueue a message; no-op if queue is gone."""
    queue = get_queue(name, account_id=account_id)
    if not queue:
        return
    encoded = base64.b64encode(message).decode()
    try:
        queue.send_message(encoded)
    except ResourceNotFoundError:
        pass
# generic pydantic model type returned by peek_queue
A = TypeVar("A", bound=BaseModel)

MIN_PEEK_SIZE = 1
MAX_PEEK_SIZE = 32

# Peek at a max of 32 messages
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue(
    name: QueueNameType,
    *,
    account_id: str,
    object_type: Type[A],
    max_messages: int = MAX_PEEK_SIZE,
) -> List[A]:
    """Peek up to max_messages from a queue, parsed as object_type models.

    Returns an empty list when the queue does not exist.  Raises
    ValueError if max_messages falls outside Azure's 1..32 peek limit.
    """
    if not (MIN_PEEK_SIZE <= max_messages <= MAX_PEEK_SIZE):
        raise ValueError("invalid max messages: %s" % max_messages)
    queue = get_queue(name, account_id=account_id)
    if queue is None:
        return []
    return [
        object_type.parse_obj(json.loads(base64.b64decode(entry.content)))
        for entry in queue.peek_messages(max_messages=max_messages)
    ]
def queue_object(
    name: QueueNameType,
    message: BaseModel,
    *,
    account_id: str,
    visibility_timeout: Optional[int] = None,
) -> bool:
    """Serialize a pydantic model and enqueue it.

    :returns: True on success, False if the queue vanished mid-send
    :raises Exception: if the queue does not exist at all
    """
    queue = get_queue(name, account_id=account_id)
    if not queue:
        # BUGFIX: the message previously formatted `queue` (always None on
        # this path) instead of the queue's name
        raise Exception("unable to queue object, no such queue: %s" % name)
    payload = base64.b64encode(message.json(exclude_none=True).encode()).decode()
    try:
        queue.send_message(payload, visibility_timeout=visibility_timeout)
    except ResourceNotFoundError:
        return False
    return True

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Optional, Union, cast
from azure.mgmt.network import NetworkManagementClient
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from .creds import mgmt_client_factory
def get_subnet_id(resource_group: str, name: str) -> Optional[str]:
    """Resource id of the subnet named `name` inside the vnet named `name`."""
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        # vnet and subnet share the same name by convention
        return cast(str, client.subnets.get(resource_group, name, name).id)
    except CloudError:
        logging.info(
            "subnet missing: resource group: %s name: %s",
            resource_group,
            name,
        )
        return None
def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
    """Delete a virtual network; returns None if it is still in use."""
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        return client.virtual_networks.delete(resource_group, name)
    except CloudError as err:
        # only swallow the "still in use" error; re-raise anything else
        if not (err.error and "InUseSubnetCannotBeDeleted" in str(err.error)):
            raise err
        logging.error(
            "subnet delete failed: %s %s : %s", resource_group, name, repr(err)
        )
        return None
def create_virtual_network(
    resource_group: str, name: str, location: str
) -> Optional[Error]:
    """Start creating a 10.0.0.0/8 vnet with a single 10.0.0.0/16 subnet."""
    logging.info(
        "creating subnet - resource group: %s name: %s location: %s",
        resource_group,
        name,
        location,
    )
    params = {
        "location": location,
        "address_space": {"address_prefixes": ["10.0.0.0/8"]},
        "subnets": [{"name": name, "address_prefix": "10.0.0.0/16"}],
    }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    client = mgmt_client_factory(NetworkManagementClient)
    try:
        client.virtual_networks.create_or_update(resource_group, name, params)
    except CloudError as err:
        return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err.message)])
    return None

View File

@ -0,0 +1,30 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Optional
from azure.cosmosdb.table import TableService
from memoization import cached
from .creds import get_storage_account_name_key
@cached(ttl=60)
def get_client(
    table: Optional[str] = None, account_id: Optional[str] = None
) -> TableService:
    """Table service client, creating `table` if requested and missing."""
    if account_id is None:
        account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
    logging.info("getting table account: (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    service = TableService(account_name=name, account_key=key)
    if table and not service.exists(table):
        logging.info("creating missing table %s", table)
        service.create_table(table, fail_on_exist=False)
    return service

View File

@ -0,0 +1,273 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import VirtualMachine
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Authentication, Error
from onefuzztypes.primitives import Extension, Region
from pydantic import BaseModel
from .creds import get_base_resource_group, mgmt_client_factory
from .disk import delete_disk, list_disks
from .image import get_os
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
def get_vm(name: str) -> Optional[VirtualMachine]:
    """Fetch a VM (with instance view); None if it does not exist."""
    resource_group = get_base_resource_group()
    # BUGFIX: the format string had three placeholders but only two
    # arguments, making every call log a formatting error
    logging.debug("getting vm: %s %s", resource_group, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return cast(
            VirtualMachine,
            compute_client.virtual_machines.get(
                resource_group, name, expand="instanceView"
            ),
        )
    except CloudError as err:
        logging.debug("vm does not exist %s", err)
        return None
def create_vm(
    name: str,
    location: str,
    vm_sku: str,
    image: str,
    password: str,
    ssh_public_key: str,
) -> Union[None, Error]:
    """Incrementally create a VM.

    If the NIC does not exist yet, its creation is started and this
    returns None early — callers are expected to retry until it exists.

    :param name: VM (and NIC) name within the base resource group
    :param location: Azure region
    :param vm_sku: VM size
    :param image: image resource id (starts with "/") or a
        "publisher:offer:sku:version" marketplace reference
    :param password: admin password for the "onefuzz" user
    :param ssh_public_key: authorized key installed on Linux images
    :returns: an Error on failure, otherwise None
    """
    resource_group = get_base_resource_group()
    logging.info("creating vm %s:%s:%s", resource_group, location, name)
    compute_client = mgmt_client_factory(ComputeManagementClient)
    nic = get_public_nic(resource_group, name)
    if nic is None:
        # NIC creation is asynchronous: kick it off and come back later
        result = create_public_nic(resource_group, name, location)
        if isinstance(result, Error):
            return result
        logging.info("waiting on nic creation")
        return None
    if image.startswith("/"):
        # image given as a resource id to a custom image
        image_ref = {"id": image}
    else:
        # marketplace reference: publisher:offer:sku:version
        image_val = image.split(":", 4)
        image_ref = {
            "publisher": image_val[0],
            "offer": image_val[1],
            "sku": image_val[2],
            "version": image_val[3],
        }
    params: Dict = {
        "location": location,
        "os_profile": {
            "computer_name": "node",
            "admin_username": "onefuzz",
            "admin_password": password,
        },
        "hardware_profile": {"vm_size": vm_sku},
        "storage_profile": {"image_reference": image_ref},
        "network_profile": {"network_interfaces": [{"id": nic.id}]},
    }
    image_os = get_os(location, image)
    if isinstance(image_os, Error):
        return image_os
    if image_os == OS.linux:
        # key-only auth on Linux; password login is disabled
        params["os_profile"]["linux_configuration"] = {
            "disable_password_authentication": True,
            "ssh": {
                "public_keys": [
                    {
                        "path": "/home/onefuzz/.ssh/authorized_keys",
                        "key_data": ssh_public_key,
                    }
                ]
            },
        }
    if "ONEFUZZ_OWNER" in os.environ:
        params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
    try:
        compute_client.virtual_machines.create_or_update(resource_group, name, params)
    except CloudError as err:
        if "The request failed due to conflict with a concurrent request" in str(err):
            # another request is creating the same VM; treat as in-progress
            logging.debug(
                "create VM had conflicts with concurrent request, ignoring %s", err
            )
            return None
        return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
    return None
def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
    """Fetch a VM extension; None (with an error log) if missing."""
    resource_group = get_base_resource_group()
    # BUGFIX: the format string had four placeholders but only three
    # arguments, making every call log a formatting error
    logging.debug(
        "getting extension: %s:%s:%s",
        resource_group,
        vm_name,
        extension_name,
    )
    compute_client = mgmt_client_factory(ComputeManagementClient)
    try:
        return compute_client.virtual_machine_extensions.get(
            resource_group, vm_name, extension_name
        )
    except CloudError as err:
        logging.error("extension does not exist %s", err)
        return None
def create_extension(vm_name: str, extension: Dict) -> Any:
    """Start create-or-update of a VM extension (keyed by extension["name"])."""
    resource_group = get_base_resource_group()
    logging.info(
        "creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
    )
    client = mgmt_client_factory(ComputeManagementClient)
    return client.virtual_machine_extensions.create_or_update(
        resource_group, vm_name, extension["name"], extension
    )
def delete_vm(name: str) -> Any:
    """Start deletion of the named VM in the base resource group."""
    resource_group = get_base_resource_group()
    logging.info("deleting vm: %s %s", resource_group, name)
    client = mgmt_client_factory(ComputeManagementClient)
    return client.virtual_machines.delete(resource_group, name)
def has_components(name: str) -> bool:
    """True while any resource associated with the VM still exists.

    Azure VM deletion requires deleting the VM first, then each of its
    resources; this confirms everything is gone before marking it done.
    """
    resource_group = get_base_resource_group()
    if get_vm(name):
        return True
    if get_public_nic(resource_group, name):
        return True
    if get_ip(resource_group, name):
        return True
    return any(
        disk.name.startswith(name) for disk in list_disks(resource_group)
    )
def delete_vm_components(name: str) -> bool:
    """Incrementally tear down a VM and its resources.

    Deletes at most one class of resource per call and returns False
    until nothing is left, at which point it returns True.
    """
    resource_group = get_base_resource_group()
    logging.info("deleting vm components %s:%s", resource_group, name)
    if get_vm(name):
        logging.info("deleting vm %s:%s", resource_group, name)
        delete_vm(name)
        return False
    if get_public_nic(resource_group, name):
        logging.info("deleting nic %s:%s", resource_group, name)
        delete_nic(resource_group, name)
        return False
    if get_ip(resource_group, name):
        logging.info("deleting ip %s:%s", resource_group, name)
        delete_ip(resource_group, name)
        return False
    deleted_any = False
    for disk in list_disks(resource_group):
        if disk.name.startswith(name):
            logging.info("deleting disk %s:%s", resource_group, disk.name)
            delete_disk(resource_group, disk.name)
            deleted_any = True
    if deleted_any:
        return False
    return True
class VM(BaseModel):
    """A single (non-scaleset) VM managed through the helpers above."""

    # VM name; UUIDs are stringified before use
    name: Union[UUID, str]
    region: Region
    sku: str
    image: str
    auth: Authentication

    def is_deleted(self) -> bool:
        # NOTE(review): despite the name, this returns True while components
        # STILL exist (deletion not finished) — confirm intent with callers
        return has_components(str(self.name))

    def exists(self) -> bool:
        """True if the VM resource currently exists in Azure."""
        return self.get() is not None

    def get(self) -> Optional[VirtualMachine]:
        """Fetch the underlying Azure VM object, if any."""
        return get_vm(str(self.name))

    def create(self) -> Union[None, Error]:
        """Create the VM if it does not exist (see create_vm for semantics)."""
        if self.get() is not None:
            return None
        logging.info("vm creating: %s", self.name)
        return create_vm(
            str(self.name),
            self.region,
            self.sku,
            self.image,
            self.auth.password,
            self.auth.public_key,
        )

    def delete(self) -> bool:
        """Incrementally delete the VM; True once everything is gone."""
        return delete_vm_components(str(self.name))

    def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
        """Install VM extensions, tracking their provisioning state.

        Returns True once every extension provisioned successfully, an
        Error if any reported "Failed", and False while work is in flight.
        """
        status = []
        to_create = []
        for config in extensions:
            # extension configs must carry a string name
            if not isinstance(config["name"], str):
                logging.error("vm agent - incompatable name: %s", repr(config))
                continue
            extension = get_extension(str(self.name), config["name"])
            if extension:
                logging.info(
                    "vm extension state: %s - %s - %s",
                    self.name,
                    config["name"],
                    extension.provisioning_state,
                )
                status.append(extension.provisioning_state)
            else:
                to_create.append(config)
        if to_create:
            # kick off creation of missing extensions; report in-progress
            for config in to_create:
                create_extension(str(self.name), config)
        else:
            if all([x == "Succeeded" for x in status]):
                return True
            elif "Failed" in status:
                return Error(
                    code=ErrorCode.VM_CREATE_FAILED,
                    errors=["failed to launch extension"],
                )
            elif not ("Creating" in status or "Updating" in status):
                logging.error("vm agent - unknown state %s: %s", self.name, status)
        return False

View File

@ -0,0 +1,334 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import ResourceSku, ResourceSkuRestrictionsType
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import get_base_resource_group, mgmt_client_factory
from .image import get_os
def list_vmss(name: UUID) -> Optional[List[str]]:
    """Instance ids of all VMs in the scaleset; None on a cloud error."""
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        vms = client.virtual_machine_scale_set_vms.list(
            get_base_resource_group(), str(name)
        )
        return [vm.instance_id for vm in vms]
    except CloudError as err:
        logging.error("cloud error listing vmss: %s (%s)", name, err)
        return None
def delete_vmss(name: UUID) -> Any:
    """Best-effort delete of a scaleset; cloud errors are logged, not raised."""
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        client.virtual_machine_scale_sets.delete(get_base_resource_group(), str(name))
    except CloudError as err:
        logging.error("cloud error deleting vmss: %s (%s)", name, err)
def get_vmss(name: UUID) -> Optional[Any]:
    """Fetch the scaleset model for *name*, or None if it does not exist."""
    logging.debug("getting vm: %s", name)
    group = get_base_resource_group()
    client = mgmt_client_factory(ComputeManagementClient)
    try:
        return client.virtual_machine_scale_sets.get(group, str(name))
    except CloudError as err:
        logging.debug("vm does not exist %s", err)
        return None
def resize_vmss(name: UUID, capacity: int) -> None:
    """Set the scaleset's VM count to *capacity*.

    Raises:
        UnableToUpdate: if the scaleset is missing or still provisioning.
    """
    check_can_update(name)
    logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
    client = mgmt_client_factory(ComputeManagementClient)
    client.virtual_machine_scale_sets.update(
        get_base_resource_group(), str(name), {"sku": {"capacity": capacity}}
    )
def get_vmss_size(name: UUID) -> Optional[int]:
    """Return the scaleset's current capacity, or None if it does not exist."""
    vmss = get_vmss(name)
    return None if vmss is None else cast(int, vmss.sku.capacity)
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
    """Map each node's unique vm_id to its scaleset instance ID.

    Returns an empty dict when the scaleset is unavailable.
    """
    logging.info("get instance IDs for scaleset: %s", name)
    group = get_base_resource_group()
    client = mgmt_client_factory(ComputeManagementClient)
    mapping: Dict[UUID, str] = {}
    try:
        vms = client.virtual_machine_scale_set_vms.list(group, str(name))
        for vm in vms:
            mapping[UUID(vm.vm_id)] = cast(str, vm.instance_id)
    except CloudError:
        logging.debug("scaleset not available: %s", name)
    return mapping
@cached(ttl=60)
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
    """Resolve a node's vm_id to its scaleset instance ID (cached for 60s).

    Returns an Error (UNABLE_TO_FIND) if no node in the scaleset matches.
    """
    logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
    group = get_base_resource_group()
    client = mgmt_client_factory(ComputeManagementClient)
    wanted = str(vm_id)
    for vm in client.virtual_machine_scale_set_vms.list(group, str(name)):
        if vm.vm_id == wanted:
            return cast(str, vm.instance_id)
    return Error(
        code=ErrorCode.UNABLE_TO_FIND,
        errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
    )
class UnableToUpdate(Exception):
    # Raised when a scaleset cannot be modified: it does not exist or its
    # provisioning_state is not "Succeeded" (see check_can_update).
    pass
def check_can_update(name: UUID) -> Any:
    """Return the scaleset model if it exists and is fully provisioned.

    Raises:
        UnableToUpdate: if the scaleset is missing or its provisioning
            state is anything other than "Succeeded".
    """
    vmss = get_vmss(name)
    if vmss is None or vmss.provisioning_state != "Succeeded":
        raise UnableToUpdate
    return vmss
def reimage_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
    """Reimage the given nodes of a scaleset back to their base image.

    Unknown vm_ids are logged and skipped; nothing is issued if none of
    them resolve to an instance.  Raises UnableToUpdate if the scaleset
    cannot currently be modified.
    """
    check_can_update(name)
    logging.info("reimaging scaleset VM - name: %s vm_ids:%s", name, vm_ids)
    client = mgmt_client_factory(ComputeManagementClient)
    known = list_instance_ids(name)
    to_reimage = []
    for vm_id in vm_ids:
        instance = known.get(vm_id)
        if instance is None:
            logging.info("unable to find vm_id for %s:%s", name, vm_id)
        else:
            to_reimage.append(instance)
    if to_reimage:
        client.virtual_machine_scale_sets.reimage_all(
            get_base_resource_group(), str(name), instance_ids=to_reimage
        )
    return None
def delete_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
    """Delete the given nodes from a scaleset.

    Unknown vm_ids are logged and skipped; nothing is issued if none of
    them resolve to an instance.  Raises UnableToUpdate if the scaleset
    cannot currently be modified.
    """
    check_can_update(name)
    logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
    client = mgmt_client_factory(ComputeManagementClient)
    known = list_instance_ids(name)
    to_delete = []
    for vm_id in vm_ids:
        instance = known.get(vm_id)
        if instance is None:
            logging.info("unable to find vm_id for %s:%s", name, vm_id)
        else:
            to_delete.append(instance)
    if to_delete:
        client.virtual_machine_scale_sets.delete_instances(
            get_base_resource_group(), str(name), instance_ids=to_delete
        )
    return None
def update_extensions(name: UUID, extensions: List[Any]) -> None:
    """Replace the scaleset's VM extension profile with *extensions*.

    Raises UnableToUpdate if the scaleset cannot currently be modified.
    """
    check_can_update(name)
    logging.info("updating VM extensions: %s", name)
    client = mgmt_client_factory(ComputeManagementClient)
    profile = {"extension_profile": {"extensions": extensions}}
    client.virtual_machine_scale_sets.update(
        get_base_resource_group(),
        str(name),
        {"virtual_machine_profile": profile},
    )
def create_vmss(
    location: Region,
    name: UUID,
    vm_sku: str,
    vm_count: int,
    image: str,
    network_id: str,
    spot_instances: bool,
    extensions: List[Any],
    password: str,
    ssh_public_key: str,
    tags: Dict[str, str],
) -> Optional[Error]:
    """Create a VM scaleset, returning None on success or an Error on failure.

    Idempotent: if a scaleset with this name already exists, nothing is done.

    :param image: either a full resource ID (starts with "/") or a
        marketplace reference "publisher:offer:sku:version".
    :param spot_instances: if True, use Spot VMs with delete-on-evict and
        no price-based eviction.
    """
    vmss = get_vmss(name)
    if vmss is not None:
        return None

    # fix: the original adjacent string literals had no separator, producing
    # the garbled log line "creating VM countname: ..."
    logging.info(
        "creating VM count - "
        "name: %s vm_sku: %s vm_count: %d "
        "image: %s subnet: %s spot_instances: %s",
        name,
        vm_sku,
        vm_count,
        image,
        network_id,
        spot_instances,
    )

    resource_group = get_base_resource_group()
    compute_client = mgmt_client_factory(ComputeManagementClient)

    if image.startswith("/"):
        image_ref = {"id": image}
    else:
        image_val = image.split(":", 4)
        # fix: a malformed marketplace image previously raised IndexError;
        # report it through the Error channel like every other failure
        if len(image_val) < 4:
            return Error(
                code=ErrorCode.VM_CREATE_FAILED,
                errors=["invalid image: %s" % image],
            )
        image_ref = {
            "publisher": image_val[0],
            "offer": image_val[1],
            "sku": image_val[2],
            "version": image_val[3],
        }

    sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}

    params: Dict[str, Any] = {
        "location": location,
        "do_not_run_extensions_on_overprovisioned_vms": True,
        # Manual upgrades: extension/profile changes are rolled out explicitly
        "upgrade_policy": {"mode": "Manual"},
        "sku": sku,
        "identity": {"type": "SystemAssigned"},
        "virtual_machine_profile": {
            "priority": "Regular",
            "storage_profile": {"image_reference": image_ref},
            "os_profile": {
                "computer_name_prefix": "node",
                "admin_username": "onefuzz",
                "admin_password": password,
            },
            "network_profile": {
                "network_interface_configurations": [
                    {
                        "name": "onefuzz-nic",
                        "primary": True,
                        "ip_configurations": [
                            {"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
                        ],
                    }
                ]
            },
            "extension_profile": {"extensions": extensions},
        },
    }

    image_os = get_os(location, image)
    if isinstance(image_os, Error):
        return image_os

    if image_os == OS.linux:
        # Linux nodes use key-based SSH only; the password above is still
        # required by the API but password auth is disabled
        params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
            "disable_password_authentication": True,
            "ssh": {
                "public_keys": [
                    {
                        "path": "/home/onefuzz/.ssh/authorized_keys",
                        "key_data": ssh_public_key,
                    }
                ]
            },
        }

    if spot_instances:
        # Setting max price to -1 means it won't be evicted because of
        # price.
        #
        # https://docs.microsoft.com/en-us/azure/
        # virtual-machine-scale-sets/use-spot#resource-manager-templates
        params["virtual_machine_profile"].update(
            {
                "eviction_policy": "Delete",
                "priority": "Spot",
                "billing_profile": {"max_price": -1},
            }
        )

    params["tags"] = tags.copy()
    owner = os.environ.get("ONEFUZZ_OWNER")
    if owner:
        params["tags"]["OWNER"] = owner

    try:
        # fix: pass str(name) for consistency with every other SDK call here
        compute_client.virtual_machine_scale_sets.create_or_update(
            resource_group, str(name), params
        )
    except CloudError as err:
        # concurrent create attempts are benign - the other request wins
        if "The request failed due to conflict with a concurrent request" in repr(err):
            logging.debug(
                "create VM had conflicts with concurrent request, ignoring %s", err
            )
            return None
        return Error(
            code=ErrorCode.VM_CREATE_FAILED,
            errors=["creating vmss: %s" % err],
        )
    return None
@cached(ttl=60)
def list_available_skus(location: str) -> List[str]:
    """List compute SKU names usable in *location* (cached for 60 seconds).

    SKUs carrying a location-type restriction that names this region are
    filtered out.
    """
    compute_client = mgmt_client_factory(ComputeManagementClient)
    all_skus: List[ResourceSku] = list(
        compute_client.resource_skus.list(filter="location eq '%s'" % location)
    )

    def _restricted_here(sku: ResourceSku) -> bool:
        # unusable if any location restriction lists this region
        if sku.restrictions is None:
            return False
        target = location.upper()
        for restriction in sku.restrictions:
            if restriction.type == ResourceSkuRestrictionsType.location and (
                target in [v.upper() for v in restriction.values]
            ):
                return True
        return False

    return [sku.name for sku in all_skus if not _restricted_here(sku)]