mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 04:18:07 +00:00
initial public release
This commit is contained in:
37
src/api-service/__app__/onefuzzlib/azure/auth.py
Normal file
37
src/api-service/__app__/onefuzzlib/azure/auth.py
Normal file
@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import subprocess # nosec - used for ssh key generation
|
||||
import tempfile
|
||||
from typing import Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from onefuzztypes.models import Authentication
|
||||
|
||||
|
||||
def generate_keypair() -> Tuple[str, str]:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filename = os.path.join(tmpdir, "key")
|
||||
|
||||
cmd = ["ssh-keygen", "-t", "rsa", "-f", filename, "-P", "", "-b", "2048"]
|
||||
subprocess.check_output(cmd) # nosec - all arguments are under our control
|
||||
|
||||
with open(filename, "r") as handle:
|
||||
private = handle.read()
|
||||
|
||||
with open(filename + ".pub", "r") as handle:
|
||||
public = handle.read().strip()
|
||||
|
||||
return (public, private)
|
||||
|
||||
|
||||
def build_auth() -> Authentication:
|
||||
public_key, private_key = generate_keypair()
|
||||
auth = Authentication(
|
||||
password=str(uuid4()), public_key=public_key, private_key=private_key
|
||||
)
|
||||
return auth
|
155
src/api-service/__app__/onefuzzlib/azure/containers.py
Normal file
155
src/api-service/__app__/onefuzzlib/azure/containers.py
Normal file
@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, Union, cast
|
||||
|
||||
from azure.common import AzureHttpError, AzureMissingResourceHttpError
|
||||
from azure.storage.blob import BlobPermissions, ContainerPermissions
|
||||
|
||||
from .creds import get_blob_service
|
||||
|
||||
|
||||
def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]:
|
||||
return {
|
||||
x.name: x.metadata
|
||||
for x in get_blob_service(account_id).list_containers(include_metadata=True)
|
||||
if not x.name.startswith("$")
|
||||
}
|
||||
|
||||
|
||||
def get_container_metadata(
|
||||
name: str, account_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, str]]:
|
||||
try:
|
||||
result = get_blob_service(account_id).get_container_metadata(name)
|
||||
if result:
|
||||
return cast(Dict[str, str], result)
|
||||
except AzureHttpError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def create_container(
|
||||
name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
get_blob_service(account_id).create_container(name, metadata=metadata)
|
||||
except AzureHttpError:
|
||||
# azure storage already logs errors
|
||||
return None
|
||||
|
||||
return get_container_sas_url(
|
||||
name, read=True, add=True, create=True, write=True, delete=True, list=True
|
||||
)
|
||||
|
||||
|
||||
def delete_container(name: str, account_id: Optional[str] = None) -> bool:
|
||||
try:
|
||||
return bool(get_blob_service(account_id).delete_container(name))
|
||||
except AzureHttpError:
|
||||
# azure storage already logs errors
|
||||
return False
|
||||
|
||||
|
||||
def get_container_sas_url(
|
||||
container: str,
|
||||
account_id: Optional[str] = None,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
create: bool = False,
|
||||
write: bool = False,
|
||||
delete: bool = False,
|
||||
list: bool = False,
|
||||
) -> str:
|
||||
service = get_blob_service(account_id)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
permission = ContainerPermissions(read, add, create, write, delete, list)
|
||||
|
||||
sas_token = service.generate_container_shared_access_signature(
|
||||
container, permission=permission, expiry=expiry
|
||||
)
|
||||
|
||||
url = service.make_container_url(container, sas_token=sas_token)
|
||||
url = url.replace("?restype=container&", "?")
|
||||
return str(url)
|
||||
|
||||
|
||||
def get_file_sas_url(
|
||||
container: str,
|
||||
name: str,
|
||||
account_id: Optional[str] = None,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
create: bool = False,
|
||||
write: bool = False,
|
||||
delete: bool = False,
|
||||
list: bool = False,
|
||||
days: int = 30,
|
||||
hours: int = 0,
|
||||
minutes: int = 0,
|
||||
) -> str:
|
||||
service = get_blob_service(account_id)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(
|
||||
days=days, hours=hours, minutes=minutes
|
||||
)
|
||||
permission = BlobPermissions(read, add, create, write, delete, list)
|
||||
|
||||
sas_token = service.generate_blob_shared_access_signature(
|
||||
container, name, permission=permission, expiry=expiry
|
||||
)
|
||||
|
||||
url = service.make_blob_url(container, name, sas_token=sas_token)
|
||||
return str(url)
|
||||
|
||||
|
||||
def save_blob(
|
||||
container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
|
||||
) -> None:
|
||||
service = get_blob_service(account_id)
|
||||
service.create_container(container)
|
||||
if isinstance(data, str):
|
||||
service.create_blob_from_text(container, name, data)
|
||||
elif isinstance(data, bytes):
|
||||
service.create_blob_from_bytes(container, name, data)
|
||||
|
||||
|
||||
def get_blob(
|
||||
container: str, name: str, account_id: Optional[str] = None
|
||||
) -> Optional[Any]: # should be bytes
|
||||
service = get_blob_service(account_id)
|
||||
try:
|
||||
blob = service.get_blob_to_bytes(container, name).content
|
||||
return blob
|
||||
except AzureMissingResourceHttpError:
|
||||
return None
|
||||
|
||||
|
||||
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
|
||||
service = get_blob_service(account_id)
|
||||
try:
|
||||
service.get_blob_properties(container, name)
|
||||
return True
|
||||
except AzureMissingResourceHttpError:
|
||||
return False
|
||||
|
||||
|
||||
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
|
||||
service = get_blob_service(account_id)
|
||||
try:
|
||||
service.delete_blob(container, name)
|
||||
return True
|
||||
except AzureMissingResourceHttpError:
|
||||
return False
|
||||
|
||||
|
||||
def auth_download_url(container: str, filename: str) -> str:
|
||||
instance = os.environ["ONEFUZZ_INSTANCE"]
|
||||
return "%s/api/download?%s" % (
|
||||
instance,
|
||||
urllib.parse.urlencode({"container": container, "filename": filename}),
|
||||
)
|
116
src/api-service/__app__/onefuzzlib/azure/creds.py
Normal file
116
src/api-service/__app__/onefuzzlib/azure/creds.py
Normal file
@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from azure.cli.core import CLIError
|
||||
from azure.common.client_factory import get_client_from_cli_profile
|
||||
from azure.graphrbac import GraphRbacManagementClient
|
||||
from azure.graphrbac.models import CheckGroupMembershipParameters
|
||||
from azure.mgmt.resource import ResourceManagementClient
|
||||
from azure.mgmt.storage import StorageManagementClient
|
||||
from azure.mgmt.subscription import SubscriptionClient
|
||||
from azure.storage.blob import BlockBlobService
|
||||
from memoization import cached
|
||||
from msrestazure.azure_active_directory import MSIAuthentication
|
||||
from msrestazure.tools import parse_resource_id
|
||||
|
||||
from .monkeypatch import allow_more_workers
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_msi() -> MSIAuthentication:
|
||||
return MSIAuthentication()
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def mgmt_client_factory(client_class: Any) -> Any:
|
||||
allow_more_workers()
|
||||
try:
|
||||
return get_client_from_cli_profile(client_class)
|
||||
except CLIError:
|
||||
if issubclass(client_class, SubscriptionClient):
|
||||
return client_class(get_msi())
|
||||
else:
|
||||
return client_class(get_msi(), get_subscription())
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]:
|
||||
db_client = mgmt_client_factory(StorageManagementClient)
|
||||
if account_id is None:
|
||||
account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
|
||||
resource = parse_resource_id(account_id)
|
||||
key = (
|
||||
db_client.storage_accounts.list_keys(
|
||||
resource["resource_group"], resource["name"]
|
||||
)
|
||||
.keys[0]
|
||||
.value
|
||||
)
|
||||
return resource["name"], key
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
|
||||
logging.info("getting blob container (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
service = BlockBlobService(account_name=name, account_key=key)
|
||||
return service
|
||||
|
||||
|
||||
@cached
|
||||
def get_base_resource_group() -> Any: # should be str
|
||||
return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_base_region() -> Any: # should be str
|
||||
client = mgmt_client_factory(ResourceManagementClient)
|
||||
group = client.resource_groups.get(get_base_resource_group())
|
||||
return group.location
|
||||
|
||||
|
||||
@cached
|
||||
def get_subscription() -> Any: # should be str
|
||||
return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_fuzz_storage() -> str:
|
||||
return os.environ["ONEFUZZ_DATA_STORAGE"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_func_storage() -> str:
|
||||
return os.environ["ONEFUZZ_FUNC_STORAGE"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_instance_name() -> str:
|
||||
return os.environ["ONEFUZZ_INSTANCE_NAME"]
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_regions() -> List[str]:
|
||||
client = mgmt_client_factory(SubscriptionClient)
|
||||
subscription = get_subscription()
|
||||
locations = client.subscriptions.list_locations(subscription)
|
||||
return sorted([x.name for x in locations])
|
||||
|
||||
|
||||
def get_graph_client() -> Any:
|
||||
return mgmt_client_factory(GraphRbacManagementClient)
|
||||
|
||||
|
||||
def is_member_of(group_id: str, member_id: str) -> bool:
|
||||
client = get_graph_client()
|
||||
return bool(
|
||||
client.groups.is_member_of(
|
||||
CheckGroupMembershipParameters(group_id=group_id, member_id=member_id)
|
||||
).value
|
||||
)
|
29
src/api-service/__app__/onefuzzlib/azure/disk.py
Normal file
29
src/api-service/__app__/onefuzzlib/azure/disk.py
Normal file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from azure.mgmt.compute import ComputeManagementClient
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
|
||||
from .creds import mgmt_client_factory
|
||||
|
||||
|
||||
def list_disks(resource_group: str) -> Any:
|
||||
logging.info("listing disks %s", resource_group)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
return compute_client.disks.list_by_resource_group(resource_group)
|
||||
|
||||
|
||||
def delete_disk(resource_group: str, name: str) -> bool:
|
||||
logging.info("deleting disks %s : %s", resource_group, name)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
try:
|
||||
compute_client.disks.delete(resource_group, name)
|
||||
return True
|
||||
except CloudError as err:
|
||||
logging.error("unable to delete disk: %s", err)
|
||||
return False
|
42
src/api-service/__app__/onefuzzlib/azure/image.py
Normal file
42
src/api-service/__app__/onefuzzlib/azure/image.py
Normal file
@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from azure.mgmt.compute import ComputeManagementClient
|
||||
from memoization import cached
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.enums import OS, ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .creds import mgmt_client_factory
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_os(region: Region, image: str) -> Union[Error, OS]:
|
||||
client = mgmt_client_factory(ComputeManagementClient)
|
||||
parsed = parse_resource_id(image)
|
||||
if "resource_group" in parsed:
|
||||
try:
|
||||
name = client.images.get(
|
||||
parsed["resource_group"], parsed["name"]
|
||||
).storage_profile.os_disk.os_type.name
|
||||
except CloudError as err:
|
||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
||||
else:
|
||||
publisher, offer, sku, version = image.split(":")
|
||||
try:
|
||||
if version == "latest":
|
||||
version = client.virtual_machine_images.list(
|
||||
region, publisher, offer, sku, top=1
|
||||
)[0].name
|
||||
name = client.virtual_machine_images.get(
|
||||
region, publisher, offer, sku, version
|
||||
).os_disk_image.operating_system.name
|
||||
except CloudError as err:
|
||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
||||
return OS[name]
|
150
src/api-service/__app__/onefuzzlib/azure/ip.py
Normal file
150
src/api-service/__app__/onefuzzlib/azure/ip.py
Normal file
@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.mgmt.network import NetworkManagementClient
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
|
||||
from .creds import get_base_resource_group, mgmt_client_factory
|
||||
from .subnet import create_virtual_network, get_subnet_id
|
||||
from .vmss import get_instance_id
|
||||
|
||||
|
||||
def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
|
||||
instance = get_instance_id(scaleset, machine_id)
|
||||
if not isinstance(instance, str):
|
||||
return None
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
client = mgmt_client_factory(NetworkManagementClient)
|
||||
intf = client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
|
||||
resource_group, str(scaleset)
|
||||
)
|
||||
try:
|
||||
for interface in intf:
|
||||
resource = parse_resource_id(interface.virtual_machine.id)
|
||||
if resource.get("resource_name") != instance:
|
||||
continue
|
||||
|
||||
for config in interface.ip_configurations:
|
||||
if config.private_ip_address is None:
|
||||
continue
|
||||
return str(config.private_ip_address)
|
||||
except CloudError:
|
||||
# this can fail if an interface is removed during the iteration
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_ip(resource_group: str, name: str) -> Optional[Any]:
|
||||
logging.info("getting ip %s:%s", resource_group, name)
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
try:
|
||||
return network_client.public_ip_addresses.get(resource_group, name)
|
||||
except CloudError:
|
||||
return None
|
||||
|
||||
|
||||
def delete_ip(resource_group: str, name: str) -> Any:
|
||||
logging.info("deleting ip %s:%s", resource_group, name)
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
return network_client.public_ip_addresses.delete(resource_group, name)
|
||||
|
||||
|
||||
def create_ip(resource_group: str, name: str, location: str) -> Any:
|
||||
logging.info("creating ip for %s:%s in %s", resource_group, name, location)
|
||||
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
params: Dict[str, Union[str, Dict[str, str]]] = {
|
||||
"location": location,
|
||||
"public_ip_allocation_method": "Dynamic",
|
||||
}
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
return network_client.public_ip_addresses.create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
|
||||
|
||||
def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
|
||||
logging.info("getting nic: %s %s", resource_group, name)
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
try:
|
||||
return network_client.network_interfaces.get(resource_group, name)
|
||||
except CloudError:
|
||||
return None
|
||||
|
||||
|
||||
def delete_nic(resource_group: str, name: str) -> Optional[Any]:
|
||||
logging.info("deleting nic %s:%s", resource_group, name)
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
return network_client.network_interfaces.delete(resource_group, name)
|
||||
|
||||
|
||||
def create_public_nic(resource_group: str, name: str, location: str) -> Optional[Error]:
|
||||
logging.info("creating nic for %s:%s in %s", resource_group, name, location)
|
||||
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
subnet_id = get_subnet_id(resource_group, location)
|
||||
if not subnet_id:
|
||||
return create_virtual_network(resource_group, location, location)
|
||||
|
||||
ip = get_ip(resource_group, name)
|
||||
if not ip:
|
||||
create_ip(resource_group, name, location)
|
||||
return None
|
||||
|
||||
params = {
|
||||
"location": location,
|
||||
"ip_configurations": [
|
||||
{
|
||||
"name": "myIPConfig",
|
||||
"public_ip_address": ip,
|
||||
"subnet": {"id": subnet_id},
|
||||
}
|
||||
],
|
||||
}
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
try:
|
||||
network_client.network_interfaces.create_or_update(resource_group, name, params)
|
||||
except CloudError as err:
|
||||
if "RetryableError" not in repr(err):
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["unable to create nic: %s" % err],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_public_ip(resource_id: str) -> Optional[str]:
|
||||
logging.info("getting ip for %s", resource_id)
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
resource = parse_resource_id(resource_id)
|
||||
ip = (
|
||||
network_client.network_interfaces.get(
|
||||
resource["resource_group"], resource["name"]
|
||||
)
|
||||
.ip_configurations[0]
|
||||
.public_ip_address
|
||||
)
|
||||
resource = parse_resource_id(ip.id)
|
||||
ip = network_client.public_ip_addresses.get(
|
||||
resource["resource_group"], resource["name"]
|
||||
).ip_address
|
||||
if ip is None:
|
||||
return None
|
||||
else:
|
||||
return str(ip)
|
29
src/api-service/__app__/onefuzzlib/azure/monitor.py
Normal file
29
src/api-service/__app__/onefuzzlib/azure/monitor.py
Normal file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from azure.mgmt.loganalytics import LogAnalyticsManagementClient
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_base_resource_group, mgmt_client_factory
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_montior_client() -> Any:
|
||||
return mgmt_client_factory(LogAnalyticsManagementClient)
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_monitor_settings() -> Dict[str, str]:
|
||||
resource_group = get_base_resource_group()
|
||||
workspace_name = os.environ["ONEFUZZ_MONITOR"]
|
||||
client = get_montior_client()
|
||||
customer_id = client.workspaces.get(resource_group, workspace_name).customer_id
|
||||
shared_key = client.shared_keys.get_shared_keys(
|
||||
resource_group, workspace_name
|
||||
).primary_shared_key
|
||||
return {"id": customer_id, "key": shared_key}
|
24
src/api-service/__app__/onefuzzlib/azure/monkeypatch.py
Normal file
24
src/api-service/__app__/onefuzzlib/azure/monkeypatch.py
Normal file
@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
WORKERS_DONE = False
|
||||
|
||||
|
||||
def allow_more_workers() -> None:
|
||||
global WORKERS_DONE
|
||||
if WORKERS_DONE:
|
||||
return
|
||||
|
||||
stack = inspect.stack()
|
||||
for entry in stack:
|
||||
if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
|
||||
if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
|
||||
logging.info("bumped thread worker count to 50")
|
||||
entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50
|
||||
|
||||
WORKERS_DONE = True
|
46
src/api-service/__app__/onefuzzlib/azure/network.py
Normal file
46
src/api-service/__app__/onefuzzlib/azure/network.py
Normal file
@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .creds import get_base_resource_group
|
||||
from .subnet import create_virtual_network, delete_subnet, get_subnet_id
|
||||
|
||||
|
||||
class Network:
|
||||
def __init__(self, region: Region):
|
||||
self.group = get_base_resource_group()
|
||||
self.region = region
|
||||
|
||||
def exists(self) -> bool:
|
||||
return self.get_id() is not None
|
||||
|
||||
def get_id(self) -> Optional[str]:
|
||||
return get_subnet_id(self.group, self.region)
|
||||
|
||||
def create(self) -> Union[None, Error]:
|
||||
if not self.exists():
|
||||
result = create_virtual_network(self.group, self.region, self.region)
|
||||
if isinstance(result, CloudError):
|
||||
error = Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[result.message]
|
||||
)
|
||||
logging.error(
|
||||
"network creation failed: %s- %s",
|
||||
self.region,
|
||||
error,
|
||||
)
|
||||
return error
|
||||
|
||||
return None
|
||||
|
||||
def delete(self) -> None:
|
||||
delete_subnet(self.group, self.region)
|
153
src/api-service/__app__/onefuzzlib/azure/queue.py
Normal file
153
src/api-service/__app__/onefuzzlib/azure/queue.py
Normal file
@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
||||
from azure.storage.queue import (
|
||||
QueueSasPermissions,
|
||||
QueueServiceClient,
|
||||
generate_queue_sas,
|
||||
)
|
||||
from memoization import cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .creds import get_storage_account_name_key
|
||||
|
||||
QueueNameType = Union[str, UUID]
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_queue_client(account_id: str) -> QueueServiceClient:
|
||||
logging.info("getting blob container (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
account_url = "https://%s.queue.core.windows.net" % name
|
||||
client = QueueServiceClient(
|
||||
account_url=account_url,
|
||||
credential={"account_name": name, "account_key": key},
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_queue_sas(
|
||||
queue: QueueNameType,
|
||||
*,
|
||||
account_id: str,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
update: bool = False,
|
||||
process: bool = False,
|
||||
) -> str:
|
||||
logging.info("getting queue sas %s (account_id: %s)", queue, account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
|
||||
token = generate_queue_sas(
|
||||
name,
|
||||
str(queue),
|
||||
key,
|
||||
permission=QueueSasPermissions(
|
||||
read=read, add=add, update=update, process=process
|
||||
),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
url = "https://%s.queue.core.windows.net/%s?%s" % (name, queue, token)
|
||||
return url
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def create_queue(name: QueueNameType, *, account_id: str) -> None:
|
||||
client = get_queue_client(account_id)
|
||||
try:
|
||||
client.create_queue(str(name))
|
||||
except ResourceExistsError:
|
||||
pass
|
||||
|
||||
|
||||
def delete_queue(name: QueueNameType, *, account_id: str) -> None:
|
||||
client = get_queue_client(account_id)
|
||||
queues = client.list_queues()
|
||||
if str(name) in [x["name"] for x in queues]:
|
||||
client.delete_queue(name)
|
||||
|
||||
|
||||
def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]:
|
||||
client = get_queue_client(account_id)
|
||||
try:
|
||||
return client.get_queue_client(str(name))
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def send_message(
|
||||
name: QueueNameType,
|
||||
message: bytes,
|
||||
*,
|
||||
account_id: str,
|
||||
) -> None:
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
if queue:
|
||||
try:
|
||||
queue.send_message(base64.b64encode(message).decode())
|
||||
except ResourceNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
A = TypeVar("A", bound=BaseModel)
|
||||
|
||||
|
||||
MIN_PEEK_SIZE = 1
|
||||
MAX_PEEK_SIZE = 32
|
||||
|
||||
|
||||
# Peek at a max of 32 messages
|
||||
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
|
||||
def peek_queue(
|
||||
name: QueueNameType,
|
||||
*,
|
||||
account_id: str,
|
||||
object_type: Type[A],
|
||||
max_messages: int = MAX_PEEK_SIZE,
|
||||
) -> List[A]:
|
||||
result: List[A] = []
|
||||
|
||||
# message count
|
||||
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
|
||||
raise ValueError("invalid max messages: %s" % max_messages)
|
||||
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
if not queue:
|
||||
return result
|
||||
|
||||
for message in queue.peek_messages(max_messages=max_messages):
|
||||
decoded = base64.b64decode(message.content)
|
||||
raw = json.loads(decoded)
|
||||
result.append(object_type.parse_obj(raw))
|
||||
return result
|
||||
|
||||
|
||||
def queue_object(
|
||||
name: QueueNameType,
|
||||
message: BaseModel,
|
||||
*,
|
||||
account_id: str,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
) -> bool:
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
if not queue:
|
||||
raise Exception("unable to queue object, no such queue: %s" % queue)
|
||||
|
||||
encoded = base64.b64encode(message.json(exclude_none=True).encode()).decode()
|
||||
try:
|
||||
queue.send_message(encoded, visibility_timeout=visibility_timeout)
|
||||
return True
|
||||
except ResourceNotFoundError:
|
||||
return False
|
70
src/api-service/__app__/onefuzzlib/azure/subnet.py
Normal file
70
src/api-service/__app__/onefuzzlib/azure/subnet.py
Normal file
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from azure.mgmt.network import NetworkManagementClient
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
|
||||
from .creds import mgmt_client_factory
|
||||
|
||||
|
||||
def get_subnet_id(resource_group: str, name: str) -> Optional[str]:
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
try:
|
||||
subnet = network_client.subnets.get(resource_group, name, name)
|
||||
return cast(str, subnet.id)
|
||||
except CloudError:
|
||||
logging.info(
|
||||
"subnet missing: resource group: %s name: %s",
|
||||
resource_group,
|
||||
name,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
try:
|
||||
return network_client.virtual_networks.delete(resource_group, name)
|
||||
except CloudError as err:
|
||||
if err.error and "InUseSubnetCannotBeDeleted" in str(err.error):
|
||||
logging.error(
|
||||
"subnet delete failed: %s %s : %s", resource_group, name, repr(err)
|
||||
)
|
||||
return None
|
||||
else:
|
||||
raise err
|
||||
|
||||
|
||||
def create_virtual_network(
|
||||
resource_group: str, name: str, location: str
|
||||
) -> Optional[Error]:
|
||||
logging.info(
|
||||
"creating subnet - resource group: %s name: %s location: %s",
|
||||
resource_group,
|
||||
name,
|
||||
location,
|
||||
)
|
||||
|
||||
network_client = mgmt_client_factory(NetworkManagementClient)
|
||||
params = {
|
||||
"location": location,
|
||||
"address_space": {"address_prefixes": ["10.0.0.0/8"]},
|
||||
"subnets": [{"name": name, "address_prefix": "10.0.0.0/16"}],
|
||||
}
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
try:
|
||||
network_client.virtual_networks.create_or_update(resource_group, name, params)
|
||||
except CloudError as err:
|
||||
return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err.message)])
|
||||
|
||||
return None
|
30
src/api-service/__app__/onefuzzlib/azure/table.py
Normal file
30
src/api-service/__app__/onefuzzlib/azure/table.py
Normal file
@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from azure.cosmosdb.table import TableService
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_storage_account_name_key
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_client(
|
||||
table: Optional[str] = None, account_id: Optional[str] = None
|
||||
) -> TableService:
|
||||
if account_id is None:
|
||||
account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
|
||||
|
||||
logging.info("getting table account: (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
client = TableService(account_name=name, account_key=key)
|
||||
|
||||
if table and not client.exists(table):
|
||||
logging.info("creating missing table %s", table)
|
||||
client.create_table(table, fail_on_exist=False)
|
||||
return client
|
273
src/api-service/__app__/onefuzzlib/azure/vm.py
Normal file
273
src/api-service/__app__/onefuzzlib/azure/vm.py
Normal file
@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from azure.mgmt.compute import ComputeManagementClient
|
||||
from azure.mgmt.compute.models import VirtualMachine
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import OS, ErrorCode
|
||||
from onefuzztypes.models import Authentication, Error
|
||||
from onefuzztypes.primitives import Extension, Region
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .creds import get_base_resource_group, mgmt_client_factory
|
||||
from .disk import delete_disk, list_disks
|
||||
from .image import get_os
|
||||
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
|
||||
|
||||
|
||||
def get_vm(name: str) -> Optional[VirtualMachine]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.debug("getting vm: %s %s - %s", resource_group, name)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
try:
|
||||
return cast(
|
||||
VirtualMachine,
|
||||
compute_client.virtual_machines.get(
|
||||
resource_group, name, expand="instanceView"
|
||||
),
|
||||
)
|
||||
except CloudError as err:
|
||||
logging.debug("vm does not exist %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def create_vm(
|
||||
name: str,
|
||||
location: str,
|
||||
vm_sku: str,
|
||||
image: str,
|
||||
password: str,
|
||||
ssh_public_key: str,
|
||||
) -> Union[None, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("creating vm %s:%s:%s", resource_group, location, name)
|
||||
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
nic = get_public_nic(resource_group, name)
|
||||
if nic is None:
|
||||
result = create_public_nic(resource_group, name, location)
|
||||
if isinstance(result, Error):
|
||||
return result
|
||||
logging.info("waiting on nic creation")
|
||||
return None
|
||||
|
||||
if image.startswith("/"):
|
||||
image_ref = {"id": image}
|
||||
else:
|
||||
image_val = image.split(":", 4)
|
||||
image_ref = {
|
||||
"publisher": image_val[0],
|
||||
"offer": image_val[1],
|
||||
"sku": image_val[2],
|
||||
"version": image_val[3],
|
||||
}
|
||||
|
||||
params: Dict = {
|
||||
"location": location,
|
||||
"os_profile": {
|
||||
"computer_name": "node",
|
||||
"admin_username": "onefuzz",
|
||||
"admin_password": password,
|
||||
},
|
||||
"hardware_profile": {"vm_size": vm_sku},
|
||||
"storage_profile": {"image_reference": image_ref},
|
||||
"network_profile": {"network_interfaces": [{"id": nic.id}]},
|
||||
}
|
||||
|
||||
image_os = get_os(location, image)
|
||||
if isinstance(image_os, Error):
|
||||
return image_os
|
||||
|
||||
if image_os == OS.linux:
|
||||
params["os_profile"]["linux_configuration"] = {
|
||||
"disable_password_authentication": True,
|
||||
"ssh": {
|
||||
"public_keys": [
|
||||
{
|
||||
"path": "/home/onefuzz/.ssh/authorized_keys",
|
||||
"key_data": ssh_public_key,
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
try:
|
||||
compute_client.virtual_machines.create_or_update(resource_group, name, params)
|
||||
except CloudError as err:
|
||||
if "The request failed due to conflict with a concurrent request" in str(err):
|
||||
logging.debug(
|
||||
"create VM had conflicts with concurrent request, ignoring %s", err
|
||||
)
|
||||
return None
|
||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
|
||||
return None
|
||||
|
||||
|
||||
def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.debug(
|
||||
"getting extension: %s:%s:%s - %s",
|
||||
resource_group,
|
||||
vm_name,
|
||||
extension_name,
|
||||
)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
try:
|
||||
return compute_client.virtual_machine_extensions.get(
|
||||
resource_group, vm_name, extension_name
|
||||
)
|
||||
except CloudError as err:
|
||||
logging.error("extension does not exist %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def create_extension(vm_name: str, extension: Dict) -> Any:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info(
|
||||
"creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
|
||||
)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
return compute_client.virtual_machine_extensions.create_or_update(
|
||||
resource_group, vm_name, extension["name"], extension
|
||||
)
|
||||
|
||||
|
||||
def delete_vm(name: str) -> Any:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info("deleting vm: %s %s", resource_group, name)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
return compute_client.virtual_machines.delete(resource_group, name)
|
||||
|
||||
|
||||
def has_components(name: str) -> bool:
|
||||
# check if any of the components associated with a VM still exist.
|
||||
#
|
||||
# Azure VM Deletion requires we first delete the VM, then delete all of it's
|
||||
# resources. This is required to ensure we've cleaned it all up before
|
||||
# marking it "done"
|
||||
resource_group = get_base_resource_group()
|
||||
if get_vm(name):
|
||||
return True
|
||||
if get_public_nic(resource_group, name):
|
||||
return True
|
||||
if get_ip(resource_group, name):
|
||||
return True
|
||||
|
||||
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
|
||||
if disks:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def delete_vm_components(name: str) -> bool:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("deleting vm components %s:%s", resource_group, name)
|
||||
if get_vm(name):
|
||||
logging.info("deleting vm %s:%s", resource_group, name)
|
||||
delete_vm(name)
|
||||
return False
|
||||
|
||||
if get_public_nic(resource_group, name):
|
||||
logging.info("deleting nic %s:%s", resource_group, name)
|
||||
delete_nic(resource_group, name)
|
||||
return False
|
||||
|
||||
if get_ip(resource_group, name):
|
||||
logging.info("deleting ip %s:%s", resource_group, name)
|
||||
delete_ip(resource_group, name)
|
||||
return False
|
||||
|
||||
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
|
||||
if disks:
|
||||
for disk in disks:
|
||||
logging.info("deleting disk %s:%s", resource_group, disk)
|
||||
delete_disk(resource_group, disk)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class VM(BaseModel):
|
||||
name: Union[UUID, str]
|
||||
region: Region
|
||||
sku: str
|
||||
image: str
|
||||
auth: Authentication
|
||||
|
||||
def is_deleted(self) -> bool:
|
||||
return has_components(str(self.name))
|
||||
|
||||
def exists(self) -> bool:
|
||||
return self.get() is not None
|
||||
|
||||
def get(self) -> Optional[VirtualMachine]:
|
||||
return get_vm(str(self.name))
|
||||
|
||||
def create(self) -> Union[None, Error]:
|
||||
if self.get() is not None:
|
||||
return None
|
||||
|
||||
logging.info("vm creating: %s", self.name)
|
||||
return create_vm(
|
||||
str(self.name),
|
||||
self.region,
|
||||
self.sku,
|
||||
self.image,
|
||||
self.auth.password,
|
||||
self.auth.public_key,
|
||||
)
|
||||
|
||||
def delete(self) -> bool:
|
||||
return delete_vm_components(str(self.name))
|
||||
|
||||
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
|
||||
status = []
|
||||
to_create = []
|
||||
for config in extensions:
|
||||
if not isinstance(config["name"], str):
|
||||
logging.error("vm agent - incompatable name: %s", repr(config))
|
||||
continue
|
||||
extension = get_extension(str(self.name), config["name"])
|
||||
|
||||
if extension:
|
||||
logging.info(
|
||||
"vm extension state: %s - %s - %s",
|
||||
self.name,
|
||||
config["name"],
|
||||
extension.provisioning_state,
|
||||
)
|
||||
status.append(extension.provisioning_state)
|
||||
else:
|
||||
to_create.append(config)
|
||||
|
||||
if to_create:
|
||||
for config in to_create:
|
||||
create_extension(str(self.name), config)
|
||||
else:
|
||||
if all([x == "Succeeded" for x in status]):
|
||||
return True
|
||||
elif "Failed" in status:
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["failed to launch extension"],
|
||||
)
|
||||
elif not ("Creating" in status or "Updating" in status):
|
||||
logging.error("vm agent - unknown state %s: %s", self.name, status)
|
||||
|
||||
return False
|
334
src/api-service/__app__/onefuzzlib/azure/vmss.py
Normal file
334
src/api-service/__app__/onefuzzlib/azure/vmss.py
Normal file
@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from azure.mgmt.compute import ComputeManagementClient
|
||||
from azure.mgmt.compute.models import ResourceSku, ResourceSkuRestrictionsType
|
||||
from memoization import cached
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import OS, ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .creds import get_base_resource_group, mgmt_client_factory
|
||||
from .image import get_os
|
||||
|
||||
|
||||
def list_vmss(name: UUID) -> Optional[List[str]]:
|
||||
resource_group = get_base_resource_group()
|
||||
client = mgmt_client_factory(ComputeManagementClient)
|
||||
try:
|
||||
instances = [
|
||||
x.instance_id
|
||||
for x in client.virtual_machine_scale_set_vms.list(
|
||||
resource_group, str(name)
|
||||
)
|
||||
]
|
||||
return instances
|
||||
except CloudError as err:
|
||||
logging.error("cloud error listing vmss: %s (%s)", name, err)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def delete_vmss(name: UUID) -> Any:
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
try:
|
||||
compute_client.virtual_machine_scale_sets.delete(resource_group, str(name))
|
||||
except CloudError as err:
|
||||
logging.error("cloud error deleting vmss: %s (%s)", name, err)
|
||||
|
||||
|
||||
def get_vmss(name: UUID) -> Optional[Any]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.debug("getting vm: %s", name)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
try:
|
||||
return compute_client.virtual_machine_scale_sets.get(resource_group, str(name))
|
||||
except CloudError as err:
|
||||
logging.debug("vm does not exist %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def resize_vmss(name: UUID, capacity: int) -> None:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
compute_client.virtual_machine_scale_sets.update(
|
||||
resource_group, str(name), {"sku": {"capacity": capacity}}
|
||||
)
|
||||
|
||||
|
||||
def get_vmss_size(name: UUID) -> Optional[int]:
|
||||
vmss = get_vmss(name)
|
||||
if vmss is None:
|
||||
return None
|
||||
return cast(int, vmss.sku.capacity)
|
||||
|
||||
|
||||
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
||||
logging.info("get instance IDs for scaleset: %s", name)
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
results = {}
|
||||
try:
|
||||
for instance in compute_client.virtual_machine_scale_set_vms.list(
|
||||
resource_group, str(name)
|
||||
):
|
||||
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
|
||||
except CloudError:
|
||||
logging.debug("scaleset not available: %s", name)
|
||||
return results
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
vm_id_str = str(vm_id)
|
||||
for instance in compute_client.virtual_machine_scale_set_vms.list(
|
||||
resource_group, str(name)
|
||||
):
|
||||
if instance.vm_id == vm_id_str:
|
||||
return cast(str, instance.instance_id)
|
||||
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
|
||||
)
|
||||
|
||||
|
||||
class UnableToUpdate(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def check_can_update(name: UUID) -> Any:
|
||||
vmss = get_vmss(name)
|
||||
if vmss is None:
|
||||
raise UnableToUpdate
|
||||
|
||||
if vmss.provisioning_state != "Succeeded":
|
||||
raise UnableToUpdate
|
||||
|
||||
return vmss
|
||||
|
||||
|
||||
def reimage_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("reimaging scaleset VM - name: %s vm_ids:%s", name, vm_ids)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
instance_ids = []
|
||||
machine_to_id = list_instance_ids(name)
|
||||
for vm_id in vm_ids:
|
||||
if vm_id in machine_to_id:
|
||||
instance_ids.append(machine_to_id[vm_id])
|
||||
else:
|
||||
logging.info("unable to find vm_id for %s:%s", name, vm_id)
|
||||
|
||||
if instance_ids:
|
||||
compute_client.virtual_machine_scale_sets.reimage_all(
|
||||
resource_group, str(name), instance_ids=instance_ids
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def delete_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
instance_ids = []
|
||||
machine_to_id = list_instance_ids(name)
|
||||
for vm_id in vm_ids:
|
||||
if vm_id in machine_to_id:
|
||||
instance_ids.append(machine_to_id[vm_id])
|
||||
else:
|
||||
logging.info("unable to find vm_id for %s:%s", name, vm_id)
|
||||
|
||||
if instance_ids:
|
||||
compute_client.virtual_machine_scale_sets.delete_instances(
|
||||
resource_group, str(name), instance_ids=instance_ids
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("updating VM extensions: %s", name)
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
compute_client.virtual_machine_scale_sets.update(
|
||||
resource_group,
|
||||
str(name),
|
||||
{"virtual_machine_profile": {"extension_profile": {"extensions": extensions}}},
|
||||
)
|
||||
|
||||
|
||||
def create_vmss(
|
||||
location: Region,
|
||||
name: UUID,
|
||||
vm_sku: str,
|
||||
vm_count: int,
|
||||
image: str,
|
||||
network_id: str,
|
||||
spot_instances: bool,
|
||||
extensions: List[Any],
|
||||
password: str,
|
||||
ssh_public_key: str,
|
||||
tags: Dict[str, str],
|
||||
) -> Optional[Error]:
|
||||
|
||||
vmss = get_vmss(name)
|
||||
if vmss is not None:
|
||||
return None
|
||||
|
||||
logging.info(
|
||||
"creating VM count"
|
||||
"name: %s vm_sku: %s vm_count: %d "
|
||||
"image: %s subnet: %s spot_instances: %s",
|
||||
name,
|
||||
vm_sku,
|
||||
vm_count,
|
||||
image,
|
||||
network_id,
|
||||
spot_instances,
|
||||
)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
if image.startswith("/"):
|
||||
image_ref = {"id": image}
|
||||
else:
|
||||
image_val = image.split(":", 4)
|
||||
image_ref = {
|
||||
"publisher": image_val[0],
|
||||
"offer": image_val[1],
|
||||
"sku": image_val[2],
|
||||
"version": image_val[3],
|
||||
}
|
||||
|
||||
sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"location": location,
|
||||
"do_not_run_extensions_on_overprovisioned_vms": True,
|
||||
"upgrade_policy": {"mode": "Manual"},
|
||||
"sku": sku,
|
||||
"identity": {"type": "SystemAssigned"},
|
||||
"virtual_machine_profile": {
|
||||
"priority": "Regular",
|
||||
"storage_profile": {"image_reference": image_ref},
|
||||
"os_profile": {
|
||||
"computer_name_prefix": "node",
|
||||
"admin_username": "onefuzz",
|
||||
"admin_password": password,
|
||||
},
|
||||
"network_profile": {
|
||||
"network_interface_configurations": [
|
||||
{
|
||||
"name": "onefuzz-nic",
|
||||
"primary": True,
|
||||
"ip_configurations": [
|
||||
{"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
"extension_profile": {"extensions": extensions},
|
||||
},
|
||||
}
|
||||
|
||||
image_os = get_os(location, image)
|
||||
if isinstance(image_os, Error):
|
||||
return image_os
|
||||
|
||||
if image_os == OS.linux:
|
||||
params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
|
||||
"disable_password_authentication": True,
|
||||
"ssh": {
|
||||
"public_keys": [
|
||||
{
|
||||
"path": "/home/onefuzz/.ssh/authorized_keys",
|
||||
"key_data": ssh_public_key,
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
if spot_instances:
|
||||
# Setting max price to -1 means it won't be evicted because of
|
||||
# price.
|
||||
#
|
||||
# https://docs.microsoft.com/en-us/azure/
|
||||
# virtual-machine-scale-sets/use-spot#resource-manager-templates
|
||||
params["virtual_machine_profile"].update(
|
||||
{
|
||||
"eviction_policy": "Delete",
|
||||
"priority": "Spot",
|
||||
"billing_profile": {"max_price": -1},
|
||||
}
|
||||
)
|
||||
|
||||
params["tags"] = tags.copy()
|
||||
owner = os.environ.get("ONEFUZZ_OWNER")
|
||||
if owner:
|
||||
params["tags"]["OWNER"] = owner
|
||||
|
||||
try:
|
||||
compute_client.virtual_machine_scale_sets.create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
except CloudError as err:
|
||||
if "The request failed due to conflict with a concurrent request" in repr(err):
|
||||
logging.debug(
|
||||
"create VM had conflicts with concurrent request, ignoring %s", err
|
||||
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["creating vmss: %s" % err],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def list_available_skus(location: str) -> List[str]:
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
skus: List[ResourceSku] = list(
|
||||
compute_client.resource_skus.list(filter="location eq '%s'" % location)
|
||||
)
|
||||
sku_names: List[str] = []
|
||||
for sku in skus:
|
||||
available = True
|
||||
if sku.restrictions is not None:
|
||||
for restriction in sku.restrictions:
|
||||
if restriction.type == ResourceSkuRestrictionsType.location and (
|
||||
location.upper() in [v.upper() for v in restriction.values]
|
||||
):
|
||||
available = False
|
||||
break
|
||||
|
||||
if available:
|
||||
sku_names.append(sku.name)
|
||||
return sku_names
|
Reference in New Issue
Block a user