Revert "NSG Updated After CLI Update to Instance_Config (#1375)" (#1384)

This reverts commit 357bc4fcad.
This commit is contained in:
Noah McGregor Harper
2021-10-21 12:51:20 -07:00
committed by GitHub
parent 357bc4fcad
commit b238bfea03
7 changed files with 3 additions and 477 deletions

View File

@ -649,10 +649,6 @@ Each event will be submitted via HTTP POST to the user provided URL.
"address_space": "10.0.0.0/8", "address_space": "10.0.0.0/8",
"subnet": "10.0.0.0/16" "subnet": "10.0.0.0/16"
}, },
"proxy_nsg_config": {
"allowed_ips": [],
"allowed_service_tags": []
},
"proxy_vm_sku": "Standard_B2s" "proxy_vm_sku": "Standard_B2s"
} }
} }
@ -763,9 +759,6 @@ Each event will be submitted via HTTP POST to the user provided URL.
"network_config": { "network_config": {
"$ref": "#/definitions/NetworkConfig" "$ref": "#/definitions/NetworkConfig"
}, },
"proxy_nsg_config": {
"$ref": "#/definitions/NetworkSecurityGroupConfig"
},
"proxy_vm_sku": { "proxy_vm_sku": {
"default": "Standard_B2s", "default": "Standard_B2s",
"title": "Proxy Vm Sku", "title": "Proxy Vm Sku",
@ -821,26 +814,6 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"title": "NetworkConfig", "title": "NetworkConfig",
"type": "object" "type": "object"
},
"NetworkSecurityGroupConfig": {
"properties": {
"allowed_ips": {
"items": {
"type": "string"
},
"title": "Allowed Ips",
"type": "array"
},
"allowed_service_tags": {
"items": {
"type": "string"
},
"title": "Allowed Service Tags",
"type": "array"
}
},
"title": "NetworkSecurityGroupConfig",
"type": "object"
} }
}, },
"properties": { "properties": {
@ -5857,9 +5830,6 @@ Each event will be submitted via HTTP POST to the user provided URL.
"network_config": { "network_config": {
"$ref": "#/definitions/NetworkConfig" "$ref": "#/definitions/NetworkConfig"
}, },
"proxy_nsg_config": {
"$ref": "#/definitions/NetworkSecurityGroupConfig"
},
"proxy_vm_sku": { "proxy_vm_sku": {
"default": "Standard_B2s", "default": "Standard_B2s",
"title": "Proxy Vm Sku", "title": "Proxy Vm Sku",
@ -5967,26 +5937,6 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "NetworkConfig", "title": "NetworkConfig",
"type": "object" "type": "object"
}, },
"NetworkSecurityGroupConfig": {
"properties": {
"allowed_ips": {
"items": {
"type": "string"
},
"title": "Allowed Ips",
"type": "array"
},
"allowed_service_tags": {
"items": {
"type": "string"
},
"title": "Allowed Service Tags",
"type": "array"
}
},
"title": "NetworkSecurityGroupConfig",
"type": "object"
},
"NoReproReport": { "NoReproReport": {
"properties": { "properties": {
"error": { "error": {

View File

@ -8,11 +8,9 @@ from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error from onefuzztypes.models import Error
from onefuzztypes.requests import InstanceConfigUpdate from onefuzztypes.requests import InstanceConfigUpdate
from ..onefuzzlib.azure.nsg import set_allowed
from ..onefuzzlib.config import InstanceConfig from ..onefuzzlib.config import InstanceConfig
from ..onefuzzlib.endpoint_authorization import call_if_user, can_modify_config from ..onefuzzlib.endpoint_authorization import call_if_user, can_modify_config
from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.scalesets import Scaleset
def get(req: func.HttpRequest) -> func.HttpResponse: def get(req: func.HttpRequest) -> func.HttpResponse:
@ -32,34 +30,8 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
context="instance_config_update", context="instance_config_update",
) )
update_nsg = False
if request.config.proxy_nsg_config and config.proxy_nsg_config:
request_config = request.config.proxy_nsg_config
current_config = config.proxy_nsg_config
if set(request_config.allowed_service_tags) != set(
current_config.allowed_service_tags
) or set(request_config.allowed_ips) != set(current_config.allowed_ips):
update_nsg = True
config.update(request.config) config.update(request.config)
config.save() config.save()
# Update All NSGs
if update_nsg:
scalesets = Scaleset.search()
regions = set(x.region for x in scalesets)
for region in regions:
# nsg = get_nsg(region)
result = set_allowed(region, request.config.proxy_nsg_config)
if isinstance(result, Error):
return not_ok(
Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=["Unable to update nsg %s due to %s" % (region, result)],
),
context="instance_config_update",
)
return ok(config) return ok(config)

View File

@ -1,337 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Dict, List, Optional, Union, cast
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.mgmt.network.models import (
NetworkInterface,
NetworkSecurityGroup,
SecurityRule,
SecurityRuleAccess,
SecurityRuleProtocol,
)
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, NetworkSecurityGroupConfig
from onefuzztypes.primitives import Region
from pydantic import BaseModel, validator
from .creds import get_base_resource_group
from .network_mgmt_client import get_network_client
def is_concurrent_request_error(err: str) -> bool:
return "The request failed due to conflict with a concurrent request" in str(err)
def get_nsg(name: str) -> Optional[NetworkSecurityGroup]:
resource_group = get_base_resource_group()
logging.debug("getting nsg: %s", name)
network_client = get_network_client()
try:
nsg = network_client.network_security_groups.get(resource_group, name)
return cast(NetworkSecurityGroup, nsg)
except (ResourceNotFoundError, CloudError) as err:
logging.debug("nsg %s does not exist: %s", name, err)
return None
def create_nsg(name: str, location: Region) -> Union[None, Error]:
resource_group = get_base_resource_group()
logging.info("creating nsg %s:%s:%s", resource_group, location, name)
network_client = get_network_client()
params: Dict = {
"location": location,
}
if "ONEFUZZ_OWNER" in os.environ:
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
try:
network_client.network_security_groups.begin_create_or_update(
resource_group, name, params
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"create NSG had conflicts with concurrent request, ignoring %s", err
)
return None
return Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=["Unable to create nsg %s due to %s" % (name, err)],
)
return None
def update_nsg(nsg: NetworkSecurityGroup) -> Union[None, Error]:
resource_group = get_base_resource_group()
logging.info("updating nsg %s:%s:%s", resource_group, nsg.location, nsg.name)
network_client = get_network_client()
try:
network_client.network_security_groups.begin_create_or_update(
resource_group, nsg.name, nsg
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"create NSG had conflicts with concurrent request, ignoring %s", err
)
return None
return Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=["Unable to update nsg %s due to %s" % (nsg.name, err)],
)
return None
def delete_nsg(name: str) -> bool:
# NSG can be only deleted if no other resource is associated with it
resource_group = get_base_resource_group()
logging.info("deleting nsg: %s %s", resource_group, name)
network_client = get_network_client()
try:
network_client.network_security_groups.begin_delete(resource_group, name)
return True
except HttpResponseError as err:
err_str = str(err)
if (
"cannot be deleted because it is in use by the following resources"
) in err_str:
return False
except ResourceNotFoundError:
return True
return False
def set_allowed(name: str, sources: NetworkSecurityGroupConfig) -> Union[None, Error]:
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot update nsg rules. nsg %s not found" % name],
)
logging.info(
"setting allowed incoming connection sources for nsg: %s %s",
resource_group,
name,
)
all_sources = sources.allowed_ips + sources.allowed_service_tags
security_rules = []
# NSG security rule priority range defined here:
# https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview
min_priority = 100
# NSG rules per NSG limits:
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits
max_rule_count = 1000
if len(all_sources) > max_rule_count:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=[
"too many rules provided %d. Max allowed: %d"
% ((len(all_sources)), max_rule_count),
],
)
priority = min_priority
for src in all_sources:
security_rules.append(
SecurityRule(
name="Allow" + str(priority),
protocol=SecurityRuleProtocol.TCP,
source_port_range="*",
destination_port_range="*",
source_address_prefix=src,
destination_address_prefix="*",
access=SecurityRuleAccess.ALLOW,
priority=priority, # between 100 and 4096
direction="Inbound",
)
)
# Will not exceed `max_rule_count` or max NSG priority (4096)
# due to earlier check of `len(all_sources)`.
priority += 1
nsg.security_rules = security_rules
return update_nsg(nsg)
def clear_all_rules(name: str) -> Union[None, Error]:
return set_allowed(name, NetworkSecurityGroupConfig())
def get_all_rules(name: str) -> Union[Error, List[SecurityRule]]:
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot get nsg rules. nsg %s not found" % name],
)
return cast(List[SecurityRule], nsg.security_rules)
def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot associate nic. nsg %s not found" % name],
)
if nsg.location != nic.location:
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"network interface and nsg have to be in the same region.",
"nsg %s %s, nic: %s %s"
% (nsg.name, nsg.location, nic.name, nic.location),
],
)
if nic.network_security_group and nic.network_security_group.id == nsg.id:
return None
logging.info("associating nic %s with nsg: %s %s", nic.name, resource_group, name)
nic.network_security_group = nsg
network_client = get_network_client()
try:
network_client.network_interfaces.begin_create_or_update(
resource_group, nic.name, nic
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"associate NSG with NIC had conflicts",
"with concurrent request, ignoring %s",
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"Unable to associate nsg %s with nic %s due to %s"
% (
name,
nic.name,
err,
)
],
)
return None
def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
if nic.network_security_group is None:
return None
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot update nsg rules. nsg %s not found" % name],
)
if nsg.id != nic.network_security_group.id:
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"network interface is not associated with this nsg.",
"nsg %s, nic: %s, nic.nsg: %s"
% (
nsg.id,
nic.name,
nic.network_security_group.id,
),
],
)
logging.info("dissociating nic %s with nsg: %s %s", nic.name, resource_group, name)
nic.network_security_group = None
network_client = get_network_client()
try:
network_client.network_interfaces.begin_create_or_update(
resource_group, nic.name, nic
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"associate NSG with NIC had conflicts with ",
"concurrent request, ignoring %s",
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"Unable to associate nsg %s with nic %s due to %s"
% (
name,
nic.name,
err,
)
],
)
return None
class NSG(BaseModel):
name: str
region: Region
@validator("name", allow_reuse=True)
def check_name(cls, value: str) -> str:
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules
if len(value) > 80:
raise ValueError("NSG name too long")
return value
def create(self) -> Union[None, Error]:
# Optimization: if NSG exists - do not try
# to create it
if self.get() is not None:
return None
return create_nsg(self.name, self.region)
def delete(self) -> bool:
return delete_nsg(self.name)
def get(self) -> Optional[NetworkSecurityGroup]:
return get_nsg(self.name)
def set_allowed_sources(
self, sources: NetworkSecurityGroupConfig
) -> Union[None, Error]:
return set_allowed(self.name, sources)
def clear_all_rules(self) -> Union[None, Error]:
return clear_all_rules(self.name)
def get_all_rules(self) -> Union[Error, List[SecurityRule]]:
return get_all_rules(self.name)
def associate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
return associate_nic(self.name, nic)
def dissociate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
return dissociate_nic(self.name, nic)

View File

@ -21,7 +21,6 @@ from .creds import get_base_resource_group
from .disk import delete_disk, list_disks from .disk import delete_disk, list_disks
from .image import get_os from .image import get_os
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
from .nsg import NSG
def get_vm(name: str) -> Optional[VirtualMachine]: def get_vm(name: str) -> Optional[VirtualMachine]:
@ -48,7 +47,6 @@ def create_vm(
image: str, image: str,
password: str, password: str,
ssh_public_key: str, ssh_public_key: str,
nsg: Optional[NSG],
) -> Union[None, Error]: ) -> Union[None, Error]:
resource_group = get_base_resource_group() resource_group = get_base_resource_group()
logging.info("creating vm %s:%s:%s", resource_group, location, name) logging.info("creating vm %s:%s:%s", resource_group, location, name)
@ -62,10 +60,6 @@ def create_vm(
return result return result
logging.info("waiting on nic creation") logging.info("waiting on nic creation")
return None return None
if nsg:
result = nsg.associate_nic(nic)
if isinstance(result, Error):
return result
if image.startswith("/"): if image.startswith("/"):
image_ref = {"id": image} image_ref = {"id": image}
@ -187,7 +181,7 @@ def has_components(name: str) -> bool:
return False return False
def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool: def delete_vm_components(name: str) -> bool:
resource_group = get_base_resource_group() resource_group = get_base_resource_group()
logging.info("deleting vm components %s:%s", resource_group, name) logging.info("deleting vm components %s:%s", resource_group, name)
if get_vm(name): if get_vm(name):
@ -195,12 +189,8 @@ def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool:
delete_vm(name) delete_vm(name)
return False return False
nic = get_public_nic(resource_group, name) if get_public_nic(resource_group, name):
if nic:
logging.info("deleting nic %s:%s", resource_group, name) logging.info("deleting nic %s:%s", resource_group, name)
if nic.network_security_group and nsg:
nsg.dissociate_nic(nic)
return False
delete_nic(resource_group, name) delete_nic(resource_group, name)
return False return False
@ -225,7 +215,6 @@ class VM(BaseModel):
sku: str sku: str
image: str image: str
auth: Authentication auth: Authentication
nsg: Optional[NSG]
@validator("name", allow_reuse=True) @validator("name", allow_reuse=True)
def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]: def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]:
@ -259,11 +248,10 @@ class VM(BaseModel):
self.image, self.image,
self.auth.password, self.auth.password,
self.auth.public_key, self.auth.public_key,
self.nsg,
) )
def delete(self) -> bool: def delete(self) -> bool:
return delete_vm_components(str(self.name), self.nsg) return delete_vm_components(str(self.name))
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]: def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
status = [] status = []

View File

@ -33,7 +33,6 @@ from .azure.auth import build_auth
from .azure.containers import get_file_sas_url, save_blob from .azure.containers import get_file_sas_url, save_blob
from .azure.creds import get_instance_id from .azure.creds import get_instance_id
from .azure.ip import get_public_ip from .azure.ip import get_public_ip
from .azure.nsg import NSG
from .azure.queue import get_queue_sas from .azure.queue import get_queue_sas
from .azure.storage import StorageType from .azure.storage import StorageType
from .azure.vm import VM from .azure.vm import VM
@ -91,25 +90,6 @@ class Proxy(ORMMixin):
self.save_proxy_config() self.save_proxy_config()
self.set_state(VmState.extensions_launch) self.set_state(VmState.extensions_launch)
else: else:
nsg = NSG(
name=self.region,
region=self.region,
)
result = nsg.create()
if isinstance(result, Error):
self.set_failed(result)
return
config = InstanceConfig.fetch()
nsg_config = config.proxy_nsg_config
result = nsg.set_allowed_sources(nsg_config)
if isinstance(result, Error):
self.set_failed(result)
return
vm.nsg = nsg
result = vm.create() result = vm.create()
if isinstance(result, Error): if isinstance(result, Error):
self.set_failed(result) self.set_failed(result)

View File

@ -18,10 +18,8 @@ from .azure.auth import build_auth
from .azure.containers import save_blob from .azure.containers import save_blob
from .azure.creds import get_base_region from .azure.creds import get_base_region
from .azure.ip import get_public_ip from .azure.ip import get_public_ip
from .azure.nsg import NSG
from .azure.storage import StorageType from .azure.storage import StorageType
from .azure.vm import VM from .azure.vm import VM
from .config import InstanceConfig
from .extension import repro_extensions from .extension import repro_extensions
from .orm import ORMMixin, QueryFilter from .orm import ORMMixin, QueryFilter
from .reports import get_report from .reports import get_report
@ -87,23 +85,6 @@ class Repro(BASE_REPRO, ORMMixin):
self.state = VmState.extensions_launch self.state = VmState.extensions_launch
else: else:
nsg = NSG(
name=vm.region,
region=vm.region,
)
result = nsg.create()
if isinstance(result, Error):
self.set_failed(result)
return
config = InstanceConfig.fetch()
nsg_config = config.proxy_nsg_config
result = nsg.set_allowed_sources(nsg_config)
if isinstance(result, Error):
self.set_failed(result)
return
vm.nsg = nsg
result = vm.create() result = vm.create()
if isinstance(result, Error): if isinstance(result, Error):
self.set_error(result) self.set_error(result)

View File

@ -799,11 +799,6 @@ class NetworkConfig(BaseModel):
subnet: str = Field(default="10.0.0.0/16") subnet: str = Field(default="10.0.0.0/16")
class NetworkSecurityGroupConfig(BaseModel):
allowed_service_tags: List[str] = Field(default_factory=list)
allowed_ips: List[str] = Field(default_factory=list)
class KeyvaultExtensionConfig(BaseModel): class KeyvaultExtensionConfig(BaseModel):
keyvault_name: str keyvault_name: str
cert_name: str cert_name: str
@ -846,9 +841,6 @@ class InstanceConfig(BaseModel):
allowed_aad_tenants: List[UUID] allowed_aad_tenants: List[UUID]
network_config: NetworkConfig = Field(default_factory=NetworkConfig) network_config: NetworkConfig = Field(default_factory=NetworkConfig)
proxy_nsg_config: NetworkSecurityGroupConfig = Field(
default_factory=NetworkSecurityGroupConfig
)
extensions: Optional[AzureVmExtensionConfig] extensions: Optional[AzureVmExtensionConfig]
proxy_vm_sku: str = Field(default="Standard_B2s") proxy_vm_sku: str = Field(default="Standard_B2s")