diff --git a/docs/webhook_events.md b/docs/webhook_events.md index 53964ed3b..80fd06ca0 100644 --- a/docs/webhook_events.md +++ b/docs/webhook_events.md @@ -649,6 +649,10 @@ Each event will be submitted via HTTP POST to the user provided URL. "address_space": "10.0.0.0/8", "subnet": "10.0.0.0/16" }, + "proxy_nsg_config": { + "allowed_ips": [], + "allowed_service_tags": [] + }, "proxy_vm_sku": "Standard_B2s" } } @@ -759,6 +763,9 @@ Each event will be submitted via HTTP POST to the user provided URL. "network_config": { "$ref": "#/definitions/NetworkConfig" }, + "proxy_nsg_config": { + "$ref": "#/definitions/NetworkSecurityGroupConfig" + }, "proxy_vm_sku": { "default": "Standard_B2s", "title": "Proxy Vm Sku", @@ -814,6 +821,26 @@ Each event will be submitted via HTTP POST to the user provided URL. }, "title": "NetworkConfig", "type": "object" + }, + "NetworkSecurityGroupConfig": { + "properties": { + "allowed_ips": { + "items": { + "type": "string" + }, + "title": "Allowed Ips", + "type": "array" + }, + "allowed_service_tags": { + "items": { + "type": "string" + }, + "title": "Allowed Service Tags", + "type": "array" + } + }, + "title": "NetworkSecurityGroupConfig", + "type": "object" } }, "properties": { @@ -5830,6 +5857,9 @@ Each event will be submitted via HTTP POST to the user provided URL. "network_config": { "$ref": "#/definitions/NetworkConfig" }, + "proxy_nsg_config": { + "$ref": "#/definitions/NetworkSecurityGroupConfig" + }, "proxy_vm_sku": { "default": "Standard_B2s", "title": "Proxy Vm Sku", @@ -5937,6 +5967,26 @@ Each event will be submitted via HTTP POST to the user provided URL. "title": "NetworkConfig", "type": "object" }, + "NetworkSecurityGroupConfig": { + "properties": { + "allowed_ips": { + "items": { + "type": "string" + }, + "title": "Allowed Ips", + "type": "array" + }, + "allowed_service_tags": { + "items": { + "type": "string" + }, + "title": "Allowed Service Tags", + "type": "array" + } + }, + "title": "NetworkSecurityGroupConfig", + "type": "object" + }, "NoReproReport": { "properties": { "error": { diff --git a/src/api-service/__app__/instance_config/__init__.py b/src/api-service/__app__/instance_config/__init__.py index 800abe5f6..e8f4e3707 100644 --- a/src/api-service/__app__/instance_config/__init__.py +++ b/src/api-service/__app__/instance_config/__init__.py @@ -8,9 +8,11 @@ from onefuzztypes.enums import ErrorCode from onefuzztypes.models import Error from onefuzztypes.requests import InstanceConfigUpdate +from ..onefuzzlib.azure.nsg import set_allowed from ..onefuzzlib.config import InstanceConfig from ..onefuzzlib.endpoint_authorization import call_if_user, can_modify_config from ..onefuzzlib.request import not_ok, ok, parse_request +from ..onefuzzlib.workers.scalesets import Scaleset def get(req: func.HttpRequest) -> func.HttpResponse: @@ -30,8 +32,34 @@ def post(req: func.HttpRequest) -> func.HttpResponse: context="instance_config_update", ) + update_nsg = False + if request.config.proxy_nsg_config and config.proxy_nsg_config: + request_config = request.config.proxy_nsg_config + current_config = config.proxy_nsg_config + if set(request_config.allowed_service_tags) != set( + current_config.allowed_service_tags + ) or set(request_config.allowed_ips) != set(current_config.allowed_ips): + update_nsg = True + config.update(request.config) config.save() + + # Update All NSGs + if update_nsg: + scalesets = Scaleset.search() + regions = set(x.region for x in scalesets) + for region in regions: + # nsg = get_nsg(region) + result = set_allowed(region, request.config.proxy_nsg_config) + if isinstance(result, Error): + return not_ok( + Error( + code=ErrorCode.UNABLE_TO_CREATE, + errors=["Unable to update nsg %s due to %s" % (region, result)], + ), + context="instance_config_update", + ) + return ok(config) diff --git a/src/api-service/__app__/onefuzzlib/azure/nsg.py b/src/api-service/__app__/onefuzzlib/azure/nsg.py new file mode 100644 index 000000000..d7145169c --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/azure/nsg.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging +import os +from typing import Dict, List, Optional, Union, cast + +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError +from azure.mgmt.network.models import ( + NetworkInterface, + NetworkSecurityGroup, + SecurityRule, + SecurityRuleAccess, + SecurityRuleProtocol, +) +from msrestazure.azure_exceptions import CloudError +from onefuzztypes.enums import ErrorCode +from onefuzztypes.models import Error, NetworkSecurityGroupConfig +from onefuzztypes.primitives import Region +from pydantic import BaseModel, validator + +from .creds import get_base_resource_group +from .network_mgmt_client import get_network_client + + +def is_concurrent_request_error(err: str) -> bool: + return "The request failed due to conflict with a concurrent request" in str(err) + + +def get_nsg(name: str) -> Optional[NetworkSecurityGroup]: + resource_group = get_base_resource_group() + + logging.debug("getting nsg: %s", name) + network_client = get_network_client() + try: + nsg = network_client.network_security_groups.get(resource_group, name) + return cast(NetworkSecurityGroup, nsg) + except (ResourceNotFoundError, CloudError) as err: + logging.debug("nsg %s does not exist: %s", name, err) + return None + + +def create_nsg(name: str, location: Region) -> Union[None, Error]: + resource_group = get_base_resource_group() + + logging.info("creating nsg %s:%s:%s", resource_group, location, name) + network_client = get_network_client() + + params: Dict = { + "location": location, + } + + if "ONEFUZZ_OWNER" in os.environ: + params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]} + + try: + network_client.network_security_groups.begin_create_or_update( + resource_group, name, params + ) + except (ResourceNotFoundError, CloudError) as err: + if is_concurrent_request_error(str(err)): + logging.debug( + "create NSG had conflicts with concurrent request, ignoring %s", err + ) + return None + return Error( + code=ErrorCode.UNABLE_TO_CREATE, + errors=["Unable to create nsg %s due to %s" % (name, err)], + ) + return None + + +def update_nsg(nsg: NetworkSecurityGroup) -> Union[None, Error]: + resource_group = get_base_resource_group() + + logging.info("updating nsg %s:%s:%s", resource_group, nsg.location, nsg.name) + network_client = get_network_client() + + try: + network_client.network_security_groups.begin_create_or_update( + resource_group, nsg.name, nsg + ) + except (ResourceNotFoundError, CloudError) as err: + if is_concurrent_request_error(str(err)): + logging.debug( + "create NSG had conflicts with concurrent request, ignoring %s", err + ) + return None + return Error( + code=ErrorCode.UNABLE_TO_CREATE, + errors=["Unable to update nsg %s due to %s" % (nsg.name, err)], + ) + return None + + +def delete_nsg(name: str) -> bool: + # NSG can be only deleted if no other resource is associated with it + resource_group = get_base_resource_group() + + logging.info("deleting nsg: %s %s", resource_group, name) + network_client = get_network_client() + + try: + network_client.network_security_groups.begin_delete(resource_group, name) + return True + except HttpResponseError as err: + err_str = str(err) + if ( + "cannot be deleted because it is in use by the following resources" + ) in err_str: + return False + except ResourceNotFoundError: + return True + + return False + + +def set_allowed(name: str, sources: NetworkSecurityGroupConfig) -> Union[None, Error]: + resource_group = get_base_resource_group() + nsg = get_nsg(name) + if not nsg: + return Error( + code=ErrorCode.UNABLE_TO_FIND, + errors=["cannot update nsg rules. nsg %s not found" % name], + ) + + logging.info( + "setting allowed incoming connection sources for nsg: %s %s", + resource_group, + name, + ) + all_sources = sources.allowed_ips + sources.allowed_service_tags + security_rules = [] + # NSG security rule priority range defined here: + # https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview + min_priority = 100 + # NSG rules per NSG limits: + # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits + max_rule_count = 1000 + if len(all_sources) > max_rule_count: + return Error( + code=ErrorCode.INVALID_REQUEST, + errors=[ + "too many rules provided %d. Max allowed: %d" + % ((len(all_sources)), max_rule_count), + ], + ) + + priority = min_priority + for src in all_sources: + security_rules.append( + SecurityRule( + name="Allow" + str(priority), + protocol=SecurityRuleProtocol.TCP, + source_port_range="*", + destination_port_range="*", + source_address_prefix=src, + destination_address_prefix="*", + access=SecurityRuleAccess.ALLOW, + priority=priority, # between 100 and 4096 + direction="Inbound", + ) + ) + # Will not exceed `max_rule_count` or max NSG priority (4096) + # due to earlier check of `len(all_sources)`. + priority += 1 + + nsg.security_rules = security_rules + return update_nsg(nsg) + + +def clear_all_rules(name: str) -> Union[None, Error]: + return set_allowed(name, NetworkSecurityGroupConfig()) + + +def get_all_rules(name: str) -> Union[Error, List[SecurityRule]]: + nsg = get_nsg(name) + if not nsg: + return Error( + code=ErrorCode.UNABLE_TO_FIND, + errors=["cannot get nsg rules. nsg %s not found" % name], + ) + + return cast(List[SecurityRule], nsg.security_rules) + + +def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]: + resource_group = get_base_resource_group() + nsg = get_nsg(name) + if not nsg: + return Error( + code=ErrorCode.UNABLE_TO_FIND, + errors=["cannot associate nic. nsg %s not found" % name], + ) + + if nsg.location != nic.location: + return Error( + code=ErrorCode.UNABLE_TO_UPDATE, + errors=[ + "network interface and nsg have to be in the same region.", + "nsg %s %s, nic: %s %s" + % (nsg.name, nsg.location, nic.name, nic.location), + ], + ) + + if nic.network_security_group and nic.network_security_group.id == nsg.id: + return None + + logging.info("associating nic %s with nsg: %s %s", nic.name, resource_group, name) + + nic.network_security_group = nsg + network_client = get_network_client() + try: + network_client.network_interfaces.begin_create_or_update( + resource_group, nic.name, nic + ) + except (ResourceNotFoundError, CloudError) as err: + if is_concurrent_request_error(str(err)): + logging.debug( + "associate NSG with NIC had conflicts", + "with concurrent request, ignoring %s", + err, + ) + return None + return Error( + code=ErrorCode.UNABLE_TO_UPDATE, + errors=[ + "Unable to associate nsg %s with nic %s due to %s" + % ( + name, + nic.name, + err, + ) + ], + ) + + return None + + +def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]: + if nic.network_security_group is None: + return None + resource_group = get_base_resource_group() + nsg = get_nsg(name) + if not nsg: + return Error( + code=ErrorCode.UNABLE_TO_FIND, + errors=["cannot update nsg rules. nsg %s not found" % name], + ) + if nsg.id != nic.network_security_group.id: + return Error( + code=ErrorCode.UNABLE_TO_UPDATE, + errors=[ + "network interface is not associated with this nsg.", + "nsg %s, nic: %s, nic.nsg: %s" + % ( + nsg.id, + nic.name, + nic.network_security_group.id, + ), + ], + ) + + logging.info("dissociating nic %s with nsg: %s %s", nic.name, resource_group, name) + + nic.network_security_group = None + network_client = get_network_client() + try: + network_client.network_interfaces.begin_create_or_update( + resource_group, nic.name, nic + ) + except (ResourceNotFoundError, CloudError) as err: + if is_concurrent_request_error(str(err)): + logging.debug( + "associate NSG with NIC had conflicts with ", + "concurrent request, ignoring %s", + err, + ) + return None + return Error( + code=ErrorCode.UNABLE_TO_UPDATE, + errors=[ + "Unable to associate nsg %s with nic %s due to %s" + % ( + name, + nic.name, + err, + ) + ], + ) + + return None + + +class NSG(BaseModel): + name: str + region: Region + + @validator("name", allow_reuse=True) + def check_name(cls, value: str) -> str: + # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules + if len(value) > 80: + raise ValueError("NSG name too long") + return value + + def create(self) -> Union[None, Error]: + # Optimization: if NSG exists - do not try + # to create it + if self.get() is not None: + return None + + return create_nsg(self.name, self.region) + + def delete(self) -> bool: + return delete_nsg(self.name) + + def get(self) -> Optional[NetworkSecurityGroup]: + return get_nsg(self.name) + + def set_allowed_sources( + self, sources: NetworkSecurityGroupConfig + ) -> Union[None, Error]: + return set_allowed(self.name, sources) + + def clear_all_rules(self) -> Union[None, Error]: + return clear_all_rules(self.name) + + def get_all_rules(self) -> Union[Error, List[SecurityRule]]: + return get_all_rules(self.name) + + def associate_nic(self, nic: NetworkInterface) -> Union[None, Error]: + return associate_nic(self.name, nic) + + def dissociate_nic(self, nic: NetworkInterface) -> Union[None, Error]: + return dissociate_nic(self.name, nic) diff --git a/src/api-service/__app__/onefuzzlib/azure/vm.py b/src/api-service/__app__/onefuzzlib/azure/vm.py index f52117dcd..f2e8deb12 100644 --- a/src/api-service/__app__/onefuzzlib/azure/vm.py +++ b/src/api-service/__app__/onefuzzlib/azure/vm.py @@ -21,6 +21,7 @@ from .creds import get_base_resource_group from .disk import delete_disk, list_disks from .image import get_os from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic +from .nsg import NSG def get_vm(name: str) -> Optional[VirtualMachine]: @@ -47,6 +48,7 @@ def create_vm( image: str, password: str, ssh_public_key: str, + nsg: Optional[NSG], ) -> Union[None, Error]: resource_group = get_base_resource_group() logging.info("creating vm %s:%s:%s", resource_group, location, name) @@ -60,6 +62,10 @@ def create_vm( return result logging.info("waiting on nic creation") return None + if nsg: + result = nsg.associate_nic(nic) + if isinstance(result, Error): + return result if image.startswith("/"): image_ref = {"id": image} @@ -181,7 +187,7 @@ def has_components(name: str) -> bool: return False -def delete_vm_components(name: str) -> bool: +def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool: resource_group = get_base_resource_group() logging.info("deleting vm components %s:%s", resource_group, name) if get_vm(name): @@ -189,8 +195,12 @@ def delete_vm_components(name: str) -> bool: delete_vm(name) return False - if get_public_nic(resource_group, name): + nic = get_public_nic(resource_group, name) + if nic: logging.info("deleting nic %s:%s", resource_group, name) + if nic.network_security_group and nsg: + nsg.dissociate_nic(nic) + return False delete_nic(resource_group, name) return False @@ -215,6 +225,7 @@ class VM(BaseModel): sku: str image: str auth: Authentication + nsg: Optional[NSG] @validator("name", allow_reuse=True) def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]: @@ -248,10 +259,11 @@ class VM(BaseModel): self.image, self.auth.password, self.auth.public_key, + self.nsg, ) def delete(self) -> bool: - return delete_vm_components(str(self.name)) + return delete_vm_components(str(self.name), self.nsg) def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]: status = [] diff --git a/src/api-service/__app__/onefuzzlib/proxy.py b/src/api-service/__app__/onefuzzlib/proxy.py index c7c75bbb2..78d82220c 100644 --- a/src/api-service/__app__/onefuzzlib/proxy.py +++ b/src/api-service/__app__/onefuzzlib/proxy.py @@ -33,6 +33,7 @@ from .azure.auth import build_auth from .azure.containers import get_file_sas_url, save_blob from .azure.creds import get_instance_id from .azure.ip import get_public_ip +from .azure.nsg import NSG from .azure.queue import get_queue_sas from .azure.storage import StorageType from .azure.vm import VM @@ -90,6 +91,25 @@ class Proxy(ORMMixin): self.save_proxy_config() self.set_state(VmState.extensions_launch) else: + nsg = NSG( + name=self.region, + region=self.region, + ) + + result = nsg.create() + if isinstance(result, Error): + self.set_failed(result) + return + + config = InstanceConfig.fetch() + nsg_config = config.proxy_nsg_config + result = nsg.set_allowed_sources(nsg_config) + if isinstance(result, Error): + self.set_failed(result) + return + + vm.nsg = nsg + result = vm.create() if isinstance(result, Error): self.set_failed(result) diff --git a/src/api-service/__app__/onefuzzlib/repro.py b/src/api-service/__app__/onefuzzlib/repro.py index 38fcbebc1..937f252bd 100644 --- a/src/api-service/__app__/onefuzzlib/repro.py +++ b/src/api-service/__app__/onefuzzlib/repro.py @@ -18,8 +18,10 @@ from .azure.auth import build_auth from .azure.containers import save_blob from .azure.creds import get_base_region from .azure.ip import get_public_ip +from .azure.nsg import NSG from .azure.storage import StorageType from .azure.vm import VM +from .config import InstanceConfig from .extension import repro_extensions from .orm import ORMMixin, QueryFilter from .reports import get_report @@ -85,6 +87,23 @@ class Repro(BASE_REPRO, ORMMixin): self.state = VmState.extensions_launch else: + nsg = NSG( + name=vm.region, + region=vm.region, + ) + result = nsg.create() + if isinstance(result, Error): + self.set_failed(result) + return + + config = InstanceConfig.fetch() + nsg_config = config.proxy_nsg_config + result = nsg.set_allowed_sources(nsg_config) + if isinstance(result, Error): + self.set_failed(result) + return + + vm.nsg = nsg result = vm.create() if isinstance(result, Error): self.set_error(result) diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index 0e72090b8..a17a963a1 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -799,6 +799,11 @@ class NetworkConfig(BaseModel): subnet: str = Field(default="10.0.0.0/16") +class NetworkSecurityGroupConfig(BaseModel): + allowed_service_tags: List[str] = Field(default_factory=list) + allowed_ips: List[str] = Field(default_factory=list) + + class KeyvaultExtensionConfig(BaseModel): keyvault_name: str cert_name: str @@ -841,6 +846,9 @@ class InstanceConfig(BaseModel): allowed_aad_tenants: List[UUID] network_config: NetworkConfig = Field(default_factory=NetworkConfig) + proxy_nsg_config: NetworkSecurityGroupConfig = Field( + default_factory=NetworkSecurityGroupConfig + ) extensions: Optional[AzureVmExtensionConfig] proxy_vm_sku: str = Field(default="Standard_B2s")