mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 11:58:09 +00:00
NSG deployment on a creation of new debug/repro proxy. (#1340)
Co-authored-by: stas <statis@microsoft.com>
This commit is contained in:
334
src/api-service/__app__/onefuzzlib/azure/nsg.py
Normal file
334
src/api-service/__app__/onefuzzlib/azure/nsg.py
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
|
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
|
||||||
|
from azure.mgmt.network.models import (
|
||||||
|
NetworkInterface,
|
||||||
|
NetworkSecurityGroup,
|
||||||
|
SecurityRule,
|
||||||
|
SecurityRuleAccess,
|
||||||
|
SecurityRuleProtocol,
|
||||||
|
)
|
||||||
|
from msrestazure.azure_exceptions import CloudError
|
||||||
|
from onefuzztypes.enums import ErrorCode
|
||||||
|
from onefuzztypes.models import Error
|
||||||
|
from onefuzztypes.primitives import Region
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
|
from .creds import get_base_resource_group
|
||||||
|
from .network_mgmt_client import get_network_client
|
||||||
|
|
||||||
|
|
||||||
|
def is_concurrent_request_error(err: str) -> bool:
|
||||||
|
return "The request failed due to conflict with a concurrent request" in str(err)
|
||||||
|
|
||||||
|
|
||||||
|
def get_nsg(name: str) -> Optional[NetworkSecurityGroup]:
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
|
||||||
|
logging.debug("getting nsg: %s", name)
|
||||||
|
network_client = get_network_client()
|
||||||
|
try:
|
||||||
|
nsg = network_client.network_security_groups.get(resource_group, name)
|
||||||
|
return cast(NetworkSecurityGroup, nsg)
|
||||||
|
except (ResourceNotFoundError, CloudError) as err:
|
||||||
|
logging.debug("nsg %s does not exist: %s", name, err)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_nsg(name: str, location: Region) -> Union[None, Error]:
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
|
||||||
|
logging.info("creating nsg %s:%s:%s", resource_group, location, name)
|
||||||
|
network_client = get_network_client()
|
||||||
|
|
||||||
|
params: Dict = {
|
||||||
|
"location": location,
|
||||||
|
}
|
||||||
|
|
||||||
|
if "ONEFUZZ_OWNER" in os.environ:
|
||||||
|
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||||
|
|
||||||
|
try:
|
||||||
|
network_client.network_security_groups.begin_create_or_update(
|
||||||
|
resource_group, name, params
|
||||||
|
)
|
||||||
|
except (ResourceNotFoundError, CloudError) as err:
|
||||||
|
if is_concurrent_request_error(str(err)):
|
||||||
|
logging.debug(
|
||||||
|
"create NSG had conflicts with concurrent request, ignoring %s", err
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_CREATE,
|
||||||
|
errors=["Unable to create nsg %s due to %s" % (name, err)],
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def update_nsg(nsg: NetworkSecurityGroup) -> Union[None, Error]:
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
|
||||||
|
logging.info("updating nsg %s:%s:%s", resource_group, nsg.location, nsg.name)
|
||||||
|
network_client = get_network_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
network_client.network_security_groups.begin_create_or_update(
|
||||||
|
resource_group, nsg.name, nsg
|
||||||
|
)
|
||||||
|
except (ResourceNotFoundError, CloudError) as err:
|
||||||
|
if is_concurrent_request_error(str(err)):
|
||||||
|
logging.debug(
|
||||||
|
"create NSG had conflicts with concurrent request, ignoring %s", err
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_CREATE,
|
||||||
|
errors=["Unable to update nsg %s due to %s" % (nsg.name, err)],
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_nsg(name: str) -> bool:
|
||||||
|
# NSG can be only deleted if no other resource is associated with it
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
|
||||||
|
logging.info("deleting nsg: %s %s", resource_group, name)
|
||||||
|
network_client = get_network_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
network_client.network_security_groups.begin_delete(resource_group, name)
|
||||||
|
return True
|
||||||
|
except HttpResponseError as err:
|
||||||
|
err_str = str(err)
|
||||||
|
if (
|
||||||
|
"cannot be deleted because it is in use by the following resources"
|
||||||
|
) in err_str:
|
||||||
|
return False
|
||||||
|
except ResourceNotFoundError:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def set_allowed(name: str, sources: List[str]) -> Union[None, Error]:
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
nsg = get_nsg(name)
|
||||||
|
if not nsg:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_FIND,
|
||||||
|
errors=["cannot update nsg rules. nsg %s not found" % name],
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
"setting allowed incoming connection sources for nsg: %s %s",
|
||||||
|
resource_group,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
security_rules = []
|
||||||
|
# NSG security rule priority range defined here:
|
||||||
|
# https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview
|
||||||
|
min_priority = 100
|
||||||
|
# NSG rules per NSG limits:
|
||||||
|
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits
|
||||||
|
max_rule_count = 1000
|
||||||
|
if len(sources) > max_rule_count:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.INVALID_REQUEST,
|
||||||
|
errors=[
|
||||||
|
"too many rules provided %d. Max allowed: %d"
|
||||||
|
% ((len(sources)), max_rule_count),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
priority = min_priority
|
||||||
|
for src in sources:
|
||||||
|
security_rules.append(
|
||||||
|
SecurityRule(
|
||||||
|
name="Allow" + str(priority),
|
||||||
|
protocol=SecurityRuleProtocol.TCP,
|
||||||
|
source_port_range="*",
|
||||||
|
destination_port_range="*",
|
||||||
|
source_address_prefix=src,
|
||||||
|
destination_address_prefix="*",
|
||||||
|
access=SecurityRuleAccess.ALLOW,
|
||||||
|
priority=priority, # between 100 and 4096
|
||||||
|
direction="Inbound",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Will not exceed `max_rule_count` or max NSG priority (4096)
|
||||||
|
# due to earlier check of `len(sources)`.
|
||||||
|
priority += 1
|
||||||
|
|
||||||
|
nsg.security_rules = security_rules
|
||||||
|
return update_nsg(nsg)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_rules(name: str) -> Union[None, Error]:
|
||||||
|
return set_allowed(name, [])
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_rules(name: str) -> Union[Error, List[SecurityRule]]:
|
||||||
|
nsg = get_nsg(name)
|
||||||
|
if not nsg:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_FIND,
|
||||||
|
errors=["cannot get nsg rules. nsg %s not found" % name],
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(List[SecurityRule], nsg.security_rules)
|
||||||
|
|
||||||
|
|
||||||
|
def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
nsg = get_nsg(name)
|
||||||
|
if not nsg:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_FIND,
|
||||||
|
errors=["cannot associate nic. nsg %s not found" % name],
|
||||||
|
)
|
||||||
|
|
||||||
|
if nsg.location != nic.location:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||||
|
errors=[
|
||||||
|
"network interface and nsg have to be in the same region.",
|
||||||
|
"nsg %s %s, nic: %s %s"
|
||||||
|
% (nsg.name, nsg.location, nic.name, nic.location),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if nic.network_security_group and nic.network_security_group.id == nsg.id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logging.info("associating nic %s with nsg: %s %s", nic.name, resource_group, name)
|
||||||
|
|
||||||
|
nic.network_security_group = nsg
|
||||||
|
network_client = get_network_client()
|
||||||
|
try:
|
||||||
|
network_client.network_interfaces.begin_create_or_update(
|
||||||
|
resource_group, nic.name, nic
|
||||||
|
)
|
||||||
|
except (ResourceNotFoundError, CloudError) as err:
|
||||||
|
if is_concurrent_request_error(str(err)):
|
||||||
|
logging.debug(
|
||||||
|
"associate NSG with NIC had conflicts",
|
||||||
|
"with concurrent request, ignoring %s",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||||
|
errors=[
|
||||||
|
"Unable to associate nsg %s with nic %s due to %s"
|
||||||
|
% (
|
||||||
|
name,
|
||||||
|
nic.name,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
|
||||||
|
if nic.network_security_group is None:
|
||||||
|
return None
|
||||||
|
resource_group = get_base_resource_group()
|
||||||
|
nsg = get_nsg(name)
|
||||||
|
if not nsg:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_FIND,
|
||||||
|
errors=["cannot update nsg rules. nsg %s not found" % name],
|
||||||
|
)
|
||||||
|
if nsg.id != nic.network_security_group.id:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||||
|
errors=[
|
||||||
|
"network interface is not associated with this nsg.",
|
||||||
|
"nsg %s, nic: %s, nic.nsg: %s"
|
||||||
|
% (
|
||||||
|
nsg.id,
|
||||||
|
nic.name,
|
||||||
|
nic.network_security_group.id,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("dissociating nic %s with nsg: %s %s", nic.name, resource_group, name)
|
||||||
|
|
||||||
|
nic.network_security_group = None
|
||||||
|
network_client = get_network_client()
|
||||||
|
try:
|
||||||
|
network_client.network_interfaces.begin_create_or_update(
|
||||||
|
resource_group, nic.name, nic
|
||||||
|
)
|
||||||
|
except (ResourceNotFoundError, CloudError) as err:
|
||||||
|
if is_concurrent_request_error(str(err)):
|
||||||
|
logging.debug(
|
||||||
|
"associate NSG with NIC had conflicts with ",
|
||||||
|
"concurrent request, ignoring %s",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||||
|
errors=[
|
||||||
|
"Unable to associate nsg %s with nic %s due to %s"
|
||||||
|
% (
|
||||||
|
name,
|
||||||
|
nic.name,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NSG(BaseModel):
|
||||||
|
name: str
|
||||||
|
region: Region
|
||||||
|
|
||||||
|
@validator("name", allow_reuse=True)
|
||||||
|
def check_name(cls, value: str) -> str:
|
||||||
|
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules
|
||||||
|
if len(value) > 80:
|
||||||
|
raise ValueError("NSG name too long")
|
||||||
|
return value
|
||||||
|
|
||||||
|
def create(self) -> Union[None, Error]:
|
||||||
|
# Optimization: if NSG exists - do not try
|
||||||
|
# to create it
|
||||||
|
if self.get() is not None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return create_nsg(self.name, self.region)
|
||||||
|
|
||||||
|
def delete(self) -> bool:
|
||||||
|
return delete_nsg(self.name)
|
||||||
|
|
||||||
|
def get(self) -> Optional[NetworkSecurityGroup]:
|
||||||
|
return get_nsg(self.name)
|
||||||
|
|
||||||
|
def set_allowed_sources(self, sources: List[str]) -> Union[None, Error]:
|
||||||
|
return set_allowed(self.name, sources)
|
||||||
|
|
||||||
|
def clear_all_rules(self) -> Union[None, Error]:
|
||||||
|
return clear_all_rules(self.name)
|
||||||
|
|
||||||
|
def get_all_rules(self) -> Union[Error, List[SecurityRule]]:
|
||||||
|
return get_all_rules(self.name)
|
||||||
|
|
||||||
|
def associate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
|
||||||
|
return associate_nic(self.name, nic)
|
||||||
|
|
||||||
|
def dissociate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
|
||||||
|
return dissociate_nic(self.name, nic)
|
@ -21,6 +21,7 @@ from .creds import get_base_resource_group
|
|||||||
from .disk import delete_disk, list_disks
|
from .disk import delete_disk, list_disks
|
||||||
from .image import get_os
|
from .image import get_os
|
||||||
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
|
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
|
||||||
|
from .nsg import NSG
|
||||||
|
|
||||||
|
|
||||||
def get_vm(name: str) -> Optional[VirtualMachine]:
|
def get_vm(name: str) -> Optional[VirtualMachine]:
|
||||||
@ -47,6 +48,7 @@ def create_vm(
|
|||||||
image: str,
|
image: str,
|
||||||
password: str,
|
password: str,
|
||||||
ssh_public_key: str,
|
ssh_public_key: str,
|
||||||
|
nsg: Optional[NSG],
|
||||||
) -> Union[None, Error]:
|
) -> Union[None, Error]:
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
logging.info("creating vm %s:%s:%s", resource_group, location, name)
|
logging.info("creating vm %s:%s:%s", resource_group, location, name)
|
||||||
@ -60,6 +62,10 @@ def create_vm(
|
|||||||
return result
|
return result
|
||||||
logging.info("waiting on nic creation")
|
logging.info("waiting on nic creation")
|
||||||
return None
|
return None
|
||||||
|
if nsg:
|
||||||
|
result = nsg.associate_nic(nic)
|
||||||
|
if isinstance(result, Error):
|
||||||
|
return result
|
||||||
|
|
||||||
if image.startswith("/"):
|
if image.startswith("/"):
|
||||||
image_ref = {"id": image}
|
image_ref = {"id": image}
|
||||||
@ -181,7 +187,7 @@ def has_components(name: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def delete_vm_components(name: str) -> bool:
|
def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool:
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
logging.info("deleting vm components %s:%s", resource_group, name)
|
logging.info("deleting vm components %s:%s", resource_group, name)
|
||||||
if get_vm(name):
|
if get_vm(name):
|
||||||
@ -189,8 +195,12 @@ def delete_vm_components(name: str) -> bool:
|
|||||||
delete_vm(name)
|
delete_vm(name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if get_public_nic(resource_group, name):
|
nic = get_public_nic(resource_group, name)
|
||||||
|
if nic:
|
||||||
logging.info("deleting nic %s:%s", resource_group, name)
|
logging.info("deleting nic %s:%s", resource_group, name)
|
||||||
|
if nic.network_security_group and nsg:
|
||||||
|
nsg.dissociate_nic(nic)
|
||||||
|
return False
|
||||||
delete_nic(resource_group, name)
|
delete_nic(resource_group, name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -215,6 +225,7 @@ class VM(BaseModel):
|
|||||||
sku: str
|
sku: str
|
||||||
image: str
|
image: str
|
||||||
auth: Authentication
|
auth: Authentication
|
||||||
|
nsg: Optional[NSG]
|
||||||
|
|
||||||
@validator("name", allow_reuse=True)
|
@validator("name", allow_reuse=True)
|
||||||
def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]:
|
def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]:
|
||||||
@ -248,10 +259,11 @@ class VM(BaseModel):
|
|||||||
self.image,
|
self.image,
|
||||||
self.auth.password,
|
self.auth.password,
|
||||||
self.auth.public_key,
|
self.auth.public_key,
|
||||||
|
self.nsg,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self) -> bool:
|
def delete(self) -> bool:
|
||||||
return delete_vm_components(str(self.name))
|
return delete_vm_components(str(self.name), self.nsg)
|
||||||
|
|
||||||
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
|
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
|
||||||
status = []
|
status = []
|
||||||
|
@ -33,6 +33,7 @@ from .azure.auth import build_auth
|
|||||||
from .azure.containers import get_file_sas_url, save_blob
|
from .azure.containers import get_file_sas_url, save_blob
|
||||||
from .azure.creds import get_instance_id
|
from .azure.creds import get_instance_id
|
||||||
from .azure.ip import get_public_ip
|
from .azure.ip import get_public_ip
|
||||||
|
from .azure.nsg import NSG
|
||||||
from .azure.queue import get_queue_sas
|
from .azure.queue import get_queue_sas
|
||||||
from .azure.storage import StorageType
|
from .azure.storage import StorageType
|
||||||
from .azure.vm import VM
|
from .azure.vm import VM
|
||||||
@ -90,6 +91,23 @@ class Proxy(ORMMixin):
|
|||||||
self.save_proxy_config()
|
self.save_proxy_config()
|
||||||
self.set_state(VmState.extensions_launch)
|
self.set_state(VmState.extensions_launch)
|
||||||
else:
|
else:
|
||||||
|
nsg = NSG(
|
||||||
|
name=self.region,
|
||||||
|
region=self.region,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = nsg.create()
|
||||||
|
if isinstance(result, Error):
|
||||||
|
self.set_failed(result)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = nsg.set_allowed_sources(["*"])
|
||||||
|
if isinstance(result, Error):
|
||||||
|
self.set_failed(result)
|
||||||
|
return
|
||||||
|
|
||||||
|
vm.nsg = nsg
|
||||||
|
|
||||||
result = vm.create()
|
result = vm.create()
|
||||||
if isinstance(result, Error):
|
if isinstance(result, Error):
|
||||||
self.set_failed(result)
|
self.set_failed(result)
|
||||||
|
@ -18,6 +18,7 @@ from .azure.auth import build_auth
|
|||||||
from .azure.containers import save_blob
|
from .azure.containers import save_blob
|
||||||
from .azure.creds import get_base_region
|
from .azure.creds import get_base_region
|
||||||
from .azure.ip import get_public_ip
|
from .azure.ip import get_public_ip
|
||||||
|
from .azure.nsg import NSG
|
||||||
from .azure.storage import StorageType
|
from .azure.storage import StorageType
|
||||||
from .azure.vm import VM
|
from .azure.vm import VM
|
||||||
from .extension import repro_extensions
|
from .extension import repro_extensions
|
||||||
@ -85,6 +86,21 @@ class Repro(BASE_REPRO, ORMMixin):
|
|||||||
|
|
||||||
self.state = VmState.extensions_launch
|
self.state = VmState.extensions_launch
|
||||||
else:
|
else:
|
||||||
|
nsg = NSG(
|
||||||
|
name=vm.region,
|
||||||
|
region=vm.region,
|
||||||
|
)
|
||||||
|
result = nsg.create()
|
||||||
|
if isinstance(result, Error):
|
||||||
|
self.set_failed(result)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = nsg.set_allowed_sources(["*"])
|
||||||
|
if isinstance(result, Error):
|
||||||
|
self.set_failed(result)
|
||||||
|
return
|
||||||
|
|
||||||
|
vm.nsg = nsg
|
||||||
result = vm.create()
|
result = vm.create()
|
||||||
if isinstance(result, Error):
|
if isinstance(result, Error):
|
||||||
self.set_error(result)
|
self.set_error(result)
|
||||||
|
Reference in New Issue
Block a user