Files
onefuzz/src/api-service/__app__/onefuzzlib/proxy.py
Noah McGregor Harper 83e48e7e7b Adding new InstanceConfig value for VMSS & VM tags (#1560)
* Adding new instanceconfig value for tags.

* Removing bad import.

* Updating where tags are generated.

* Updating tag generation for scalesets.

* Updating tag generation in vm.

* Updating vm tag generation.

* Updating vm tag generation.

* Fixing extension.

* Fixing import.

* Fixing typing.

* Fixing get_vm calls.

* Fixing calls to get_vm.

* Fixing optional tag.
2022-01-05 13:16:03 -08:00

347 lines
11 KiB
Python

#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
import os
from typing import List, Optional, Tuple
from uuid import UUID, uuid4
import base58
from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import ErrorCode, VmState
from onefuzztypes.events import (
EventProxyCreated,
EventProxyDeleted,
EventProxyFailed,
EventProxyStateUpdated,
)
from onefuzztypes.models import (
Authentication,
Error,
Forward,
ProxyConfig,
ProxyHeartbeat,
)
from onefuzztypes.primitives import Container, Region
from pydantic import Field
from .__version__ import __version__
from .azure.auth import build_auth
from .azure.containers import get_file_sas_url, save_blob
from .azure.creds import get_instance_id
from .azure.ip import get_public_ip
from .azure.nsg import NSG
from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .azure.vm import VM
from .config import InstanceConfig
from .events import send_event
from .extension import proxy_manager_extensions
from .orm import ORMMixin, QueryFilter
from .proxy_forward import ProxyForward
# Image reference passed as `image=` when Proxy.get_vm() builds the proxy VM.
PROXY_IMAGE = "Canonical:UbuntuServer:18.04-LTS:latest"
# Prefix prepended to every log message emitted by the Proxy class below.
PROXY_LOG_PREFIX = "scaleset-proxy: "
# Age beyond which Proxy.is_outdated() reports a proxy as outdated.
PROXY_LIFESPAN = datetime.timedelta(days=7)
# This isn't intended to ever be shared to the client, hence not being in
# onefuzztypes
class Proxy(ORMMixin):
    """Persisted record describing one proxy VM in a region.

    The record is keyed by ``(region, proxy_id)`` and carries the VM's
    lifecycle ``state``, its credentials, its public IP once known, the first
    error that caused it to fail, and the last heartbeat received from it.
    """

    # Populated from the "Timestamp" alias; compared against "now" in is_alive().
    timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")
    # Defaults to a *naive* UTC value (utcnow); see _as_utc() for comparisons.
    created_timestamp: datetime.datetime = Field(
        default_factory=datetime.datetime.utcnow
    )
    proxy_id: UUID = Field(default_factory=uuid4)
    region: Region
    state: VmState = Field(default=VmState.init)
    auth: Authentication = Field(default_factory=build_auth)
    ip: Optional[str]
    # First failure recorded by set_failed(); later failures are ignored.
    error: Optional[Error]
    version: str = Field(default=__version__)
    heartbeat: Optional[ProxyHeartbeat]
    outdated: bool = Field(default=False)

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        """Return the names of the fields that form the record's key."""
        return ("region", "proxy_id")

    @staticmethod
    def _as_utc(moment: datetime.datetime) -> datetime.datetime:
        """Return *moment* as a timezone-aware UTC datetime.

        Naive values are taken to already be in UTC (this class creates its
        own timestamps with ``utcnow``), so they are tagged rather than
        shifted.  ``astimezone`` must not be used on them: it would interpret
        the naive value as *local* time.
        """
        if moment.utcoffset() is None:
            return moment.replace(tzinfo=datetime.timezone.utc)
        return moment.astimezone(datetime.timezone.utc)

    def get_vm(self, config: InstanceConfig) -> VM:
        """Build the VM description for this proxy from the instance config.

        The VM name is derived from the base58-encoded ``proxy_id``; tags are
        taken from ``config.vm_tags`` when set, otherwise ``None``.
        """
        encoded_id = base58.b58encode(self.proxy_id.bytes).decode()
        return VM(
            name="proxy-%s" % encoded_id,
            region=self.region,
            sku=config.proxy_vm_sku,
            image=PROXY_IMAGE,
            auth=self.auth,
            # A falsy vm_tags value is normalised to None.
            tags=config.vm_tags or None,
        )

    def init(self) -> None:
        """Handle the ``init`` state: create the VM, or advance if it exists.

        If the VM already exists and provisioned successfully, the proxy
        config is written and the state moves to ``extensions_launch``.
        If it does not exist, the region's NSG is created and configured and
        the VM creation is started.  Any step returning an ``Error`` marks
        the proxy as failed.
        """
        config = InstanceConfig.fetch()
        vm = self.get_vm(config)
        vm_data = vm.get()

        if vm_data:
            if vm_data.provisioning_state == "Failed":
                self.set_provision_failed(vm_data)
                return
            self.save_proxy_config()
            self.set_state(VmState.extensions_launch)
        else:
            nsg = NSG(
                name=self.region,
                region=self.region,
            )

            result = nsg.create()
            if isinstance(result, Error):
                self.set_failed(result)
                return

            result = nsg.set_allowed_sources(config.proxy_nsg_config)
            if isinstance(result, Error):
                self.set_failed(result)
                return

            vm.nsg = nsg

            result = vm.create()
            if isinstance(result, Error):
                self.set_failed(result)
                return

        self.save()

    def set_provision_failed(self, vm_data: VirtualMachine) -> None:
        """Mark the proxy failed, collecting error-level VM statuses.

        The resulting ``Error`` always starts with ``"provisioning failed"``
        followed by one entry per status whose level is ``error``.
        """
        errors = ["provisioning failed"]

        # NOTE(review): the instance view is guarded because a missing one
        # would otherwise crash the failure handler itself - confirm whether
        # VM.get() always populates it.
        instance_view = vm_data.instance_view
        statuses = instance_view.statuses if instance_view is not None else None

        errors.extend(
            f"code:{status.code} status:{status.display_status} "
            f"message:{status.message}"
            for status in statuses or []
            if status.level.name.lower() == "error"
        )

        self.set_failed(
            Error(
                code=ErrorCode.PROXY_FAILED,
                errors=errors,
            )
        )

    def set_failed(self, error: Error) -> None:
        """Record *error*, emit ``EventProxyFailed`` and move to ``stopping``.

        Only the first failure is kept: if an error is already recorded the
        call is a no-op.
        """
        if self.error is not None:
            return

        logging.error(PROXY_LOG_PREFIX + "vm failed: %s - %s", self.region, error)
        send_event(
            EventProxyFailed(region=self.region, proxy_id=self.proxy_id, error=error)
        )
        self.error = error
        self.set_state(VmState.stopping)

    def extensions_launch(self) -> None:
        """Handle the ``extensions_launch`` state.

        Fails the proxy if the VM is missing or failed provisioning.  If the
        public IP is not available yet the record is saved unchanged (so the
        step can run again later); otherwise the IP is stored, the proxy
        manager extensions are added, and a truthy result moves the proxy to
        ``running``.
        """
        config = InstanceConfig.fetch()
        vm = self.get_vm(config)
        vm_data = vm.get()

        if not vm_data:
            self.set_failed(
                Error(
                    code=ErrorCode.PROXY_FAILED,
                    errors=["azure not able to find vm"],
                )
            )
            return

        if vm_data.provisioning_state == "Failed":
            self.set_provision_failed(vm_data)
            return

        ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
        if ip is None:
            self.save()
            return
        self.ip = ip

        extensions = proxy_manager_extensions(self.region, self.proxy_id)
        result = vm.add_extensions(extensions)
        if isinstance(result, Error):
            self.set_failed(result)
            return
        if result:
            self.set_state(VmState.running)

        self.save()

    def stopping(self) -> None:
        """Handle the ``stopping`` state: delete the VM, then finish up.

        While the VM still exists a delete is requested and the record saved;
        once it is reported deleted, ``stopped()`` is invoked.
        """
        config = InstanceConfig.fetch()
        vm = self.get_vm(config)

        if vm.is_deleted():
            self.stopped()
            return

        logging.info(PROXY_LOG_PREFIX + "stopping proxy: %s", self.region)
        vm.delete()
        self.save()

    def stopped(self) -> None:
        """Move to ``stopped``, emit ``EventProxyDeleted`` and delete the record."""
        self.set_state(VmState.stopped)
        logging.info(PROXY_LOG_PREFIX + "removing proxy: %s", self.region)
        send_event(EventProxyDeleted(region=self.region, proxy_id=self.proxy_id))
        self.delete()

    def is_outdated(self) -> bool:
        """Return True if this proxy should no longer be handed out.

        A proxy is outdated when its state is not one of
        ``VmState.available()``, when its recorded version differs from the
        running service version, or when it was created more than
        ``PROXY_LIFESPAN`` ago.
        """
        if self.state not in VmState.available():
            return True

        if self.version != __version__:
            logging.info(
                PROXY_LOG_PREFIX + "mismatch version: proxy:%s service:%s state:%s",
                self.version,
                __version__,
                self.state,
            )
            return True

        if self.created_timestamp is not None:
            # created_timestamp defaults to a naive utcnow() value; ordering a
            # naive datetime against an aware one raises TypeError, so both
            # sides are normalised to aware UTC first.
            created = self._as_utc(self.created_timestamp)
            oldest_allowed = (
                datetime.datetime.now(tz=datetime.timezone.utc) - PROXY_LIFESPAN
            )
            if created < oldest_allowed:
                logging.info(
                    PROXY_LOG_PREFIX
                    + "proxy older than 7 days:proxy-created:%s state:%s",
                    self.created_timestamp,
                    self.state,
                )
                return True

        return False

    def is_used(self) -> bool:
        """Return True if the proxy currently has at least one forward."""
        if not self.get_forwards():
            logging.info(PROXY_LOG_PREFIX + "no forwards: %s", self.region)
            return False
        return True

    def is_alive(self) -> bool:
        """Return False if the proxy has been silent for over ten minutes.

        When a heartbeat exists its timestamp is checked; when none exists
        the record's own ``timestamp`` is checked instead.  A record with
        neither is considered alive.
        """
        # Timestamps reaching this method may be naive or aware; everything is
        # normalised to aware UTC so the comparisons cannot raise and are not
        # skewed by the host's local timezone.
        ten_minutes_ago = datetime.datetime.now(
            tz=datetime.timezone.utc
        ) - datetime.timedelta(minutes=10)

        if (
            self.heartbeat is not None
            and self._as_utc(self.heartbeat.timestamp) < ten_minutes_ago
        ):
            logging.error(
                PROXY_LOG_PREFIX + "last heartbeat is more than an 10 minutes old: "
                "%s - last heartbeat:%s compared_to:%s",
                self.region,
                self.heartbeat,
                ten_minutes_ago,
            )
            return False

        if (
            not self.heartbeat
            and self.timestamp
            and self._as_utc(self.timestamp) < ten_minutes_ago
        ):
            logging.error(
                PROXY_LOG_PREFIX + "no heartbeat in the last 10 minutes: "
                "%s timestamp: %s compared_to:%s",
                self.region,
                self.timestamp,
                ten_minutes_ago,
            )
            return False

        return True

    def get_forwards(self) -> List[Forward]:
        """Return the active forwards for this proxy.

        Side effect: forward entries whose ``endtime`` has passed are deleted
        while the list is built.
        """
        forwards: List[Forward] = []
        for entry in ProxyForward.search_forward(
            region=self.region, proxy_id=self.proxy_id
        ):
            if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
                entry.delete()
                continue
            forwards.append(
                Forward(
                    src_port=entry.port,
                    dst_ip=entry.dst_ip,
                    dst_port=entry.dst_port,
                )
            )
        return forwards

    def save_proxy_config(self) -> None:
        """Write this proxy's ``config.json`` blob to the config storage.

        The config embeds a read SAS URL pointing at the very blob being
        written, an add-only SAS for the ``proxy`` queue, the current
        forwards and the telemetry keys read from the environment.
        """
        container = Container("proxy-configs")
        blob_name = "%s/%s/config.json" % (self.region, self.proxy_id)

        forwards = self.get_forwards()
        proxy_config = ProxyConfig(
            url=get_file_sas_url(
                container,
                blob_name,
                StorageType.config,
                read=True,
            ),
            notification=get_queue_sas(
                "proxy",
                StorageType.config,
                add=True,
            ),
            forwards=forwards,
            region=self.region,
            proxy_id=self.proxy_id,
            instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
            microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
            instance_id=get_instance_id(),
        )

        save_blob(
            container,
            blob_name,
            proxy_config.json(),
            StorageType.config,
        )

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
        """Return proxies in any of *states*, or all proxies if none given."""
        query: QueryFilter = {}
        if states:
            query["state"] = states
        return cls.search(query=query)

    @classmethod
    def get_or_create(cls, region: Region) -> Optional["Proxy"]:
        """Return a usable proxy for *region*, creating one if necessary.

        Existing proxies not yet flagged ``outdated`` are examined in turn;
        any that ``is_outdated()`` reports stale is flagged and saved.  If no
        usable proxy remains, a new record is saved and ``EventProxyCreated``
        is emitted.
        """
        proxy_list = Proxy.search(query={"region": [region], "outdated": [False]})
        for proxy in proxy_list:
            if proxy.is_outdated():
                proxy.outdated = True
                proxy.save()
                continue
            # is_outdated() already rejects non-available states; this check
            # is kept as a defensive guard.
            if proxy.state not in VmState.available():
                continue
            return proxy

        logging.info(PROXY_LOG_PREFIX + "creating proxy: region:%s", region)
        proxy = Proxy(region=region)
        proxy.save()
        send_event(EventProxyCreated(region=region, proxy_id=proxy.proxy_id))
        return proxy

    def set_state(self, state: VmState) -> None:
        """Persist a state change and emit ``EventProxyStateUpdated``.

        Setting the state the proxy is already in is a no-op (no save, no
        event).
        """
        if self.state == state:
            return

        self.state = state
        self.save()

        send_event(
            EventProxyStateUpdated(
                region=self.region, proxy_id=self.proxy_id, state=self.state
            )
        )