Refactoring proxy lifetime to only shutdown when proxy is out-of-date. (#839)

## Summary of the Pull Request

_What is this about?_
We'd like to refactor the proxy lifecycle so that a proxy is only deleted when it is out-of-date - i.e. when the proxy is older than 7 days or has a mismatched version. I've changed two files, proxy.py and timer_daily/\_\_init\_\_.py, to check the version and timestamp before stopping a live proxy. 

## PR Checklist
* [ ] Applies to work item: #xxx
* [ ] CLA signed. If not, go over [here](https://cla.opensource.microsoft.com/microsoft/onefuzz) and sign the CLA.
* [ ] Tests added/passed
* [ ] Requires documentation to be updated
* [x] I've discussed this with core contributors already. If not checked, I'm ready to accept this work might be rejected in favor of a different grand plan. Issue number where discussion took place: #xxx

## Info on Pull Request

_What does this include?_
Changes to two files: 
proxy.py: 
- get_or_create() edited to check if the proxy's timestamp is older than 7 days.
- Created is_outdated() to check the version and timestamp for an out-of-date proxy. 
timer_daily/\_\_init\_\_.py:
- Proxy check now includes is_outdated() before determining if a proxy should be shut down. 

## Validation Steps Performed
Deploying test instance to determine if proxy lives past a single day.
This commit is contained in:
nharper285
2021-05-20 07:33:29 -07:00
committed by GitHub
parent 2b67c7b02f
commit 2f81c44f01
13 changed files with 156 additions and 37 deletions

View File

@ -1255,6 +1255,7 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"proxy_id": "00000000-0000-0000-0000-000000000000",
"region": "eastus" "region": "eastus"
} }
``` ```
@ -1264,6 +1265,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"properties": { "properties": {
"proxy_id": {
"format": "uuid",
"title": "Proxy Id",
"type": "string"
},
"region": { "region": {
"title": "Region", "title": "Region",
"type": "string" "type": "string"
@ -1283,6 +1289,7 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"proxy_id": "00000000-0000-0000-0000-000000000000",
"region": "eastus" "region": "eastus"
} }
``` ```
@ -1292,6 +1299,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"properties": { "properties": {
"proxy_id": {
"format": "uuid",
"title": "Proxy Id",
"type": "string"
},
"region": { "region": {
"title": "Region", "title": "Region",
"type": "string" "type": "string"
@ -1317,6 +1329,7 @@ Each event will be submitted via HTTP POST to the user provided URL.
"example error message" "example error message"
] ]
}, },
"proxy_id": "00000000-0000-0000-0000-000000000000",
"region": "eastus" "region": "eastus"
} }
``` ```
@ -1379,6 +1392,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
"error": { "error": {
"$ref": "#/definitions/Error" "$ref": "#/definitions/Error"
}, },
"proxy_id": {
"format": "uuid",
"title": "Proxy Id",
"type": "string"
},
"region": { "region": {
"title": "Region", "title": "Region",
"type": "string" "type": "string"
@ -4889,6 +4907,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"EventProxyCreated": { "EventProxyCreated": {
"properties": { "properties": {
"proxy_id": {
"format": "uuid",
"title": "Proxy Id",
"type": "string"
},
"region": { "region": {
"title": "Region", "title": "Region",
"type": "string" "type": "string"
@ -4902,6 +4925,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"EventProxyDeleted": { "EventProxyDeleted": {
"properties": { "properties": {
"proxy_id": {
"format": "uuid",
"title": "Proxy Id",
"type": "string"
},
"region": { "region": {
"title": "Region", "title": "Region",
"type": "string" "type": "string"
@ -4918,6 +4946,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
"error": { "error": {
"$ref": "#/definitions/Error" "$ref": "#/definitions/Error"
}, },
"proxy_id": {
"format": "uuid",
"title": "Proxy Id",
"type": "string"
},
"region": { "region": {
"title": "Region", "title": "Region",
"type": "string" "type": "string"

View File

@ -341,11 +341,11 @@ def repro_extensions(
return extensions return extensions
def proxy_manager_extensions(region: Region) -> List[Extension]: def proxy_manager_extensions(region: Region, proxy_id: UUID) -> List[Extension]:
urls = [ urls = [
get_file_sas_url( get_file_sas_url(
Container("proxy-configs"), Container("proxy-configs"),
"%s/config.json" % region, "%s/%s/config.json" % (region, proxy_id),
StorageType.config, StorageType.config,
read=True, read=True,
), ),

View File

@ -7,6 +7,7 @@ import datetime
import logging import logging
import os import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from uuid import UUID, uuid4
from azure.mgmt.compute.models import VirtualMachine from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import ErrorCode, VmState from onefuzztypes.enums import ErrorCode, VmState
@ -37,12 +38,17 @@ from .proxy_forward import ProxyForward
PROXY_SKU = "Standard_B2s" PROXY_SKU = "Standard_B2s"
PROXY_IMAGE = "Canonical:UbuntuServer:18.04-LTS:latest" PROXY_IMAGE = "Canonical:UbuntuServer:18.04-LTS:latest"
PROXY_LOG_PREFIX = "scaleset-proxy: " PROXY_LOG_PREFIX = "scaleset-proxy: "
PROXY_LIFESPAN = datetime.timedelta(days=7)
# This isn't intended to ever be shared to the client, hence not being in # This isn't intended to ever be shared to the client, hence not being in
# onefuzztypes # onefuzztypes
class Proxy(ORMMixin): class Proxy(ORMMixin):
timestamp: Optional[datetime.datetime] = Field(alias="Timestamp") timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")
created_timestamp: datetime.datetime = Field(
default_factory=datetime.datetime.utcnow
)
proxy_id: UUID = Field(default_factory=uuid4)
region: Region region: Region
state: VmState = Field(default=VmState.init) state: VmState = Field(default=VmState.init)
auth: Authentication = Field(default_factory=build_auth) auth: Authentication = Field(default_factory=build_auth)
@ -50,14 +56,15 @@ class Proxy(ORMMixin):
error: Optional[Error] error: Optional[Error]
version: str = Field(default=__version__) version: str = Field(default=__version__)
heartbeat: Optional[ProxyHeartbeat] heartbeat: Optional[ProxyHeartbeat]
outdated: bool = Field(default=False)
@classmethod @classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]: def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("region", None) return ("region", "proxy_id")
def get_vm(self) -> VM: def get_vm(self) -> VM:
vm = VM( vm = VM(
name="proxy-%s" % self.region, name="proxy-%s-%s" % (self.region, self.proxy_id),
region=self.region, region=self.region,
sku=PROXY_SKU, sku=PROXY_SKU,
image=PROXY_IMAGE, image=PROXY_IMAGE,
@ -104,7 +111,9 @@ class Proxy(ORMMixin):
return return
logging.error(PROXY_LOG_PREFIX + "vm failed: %s - %s", self.region, error) logging.error(PROXY_LOG_PREFIX + "vm failed: %s - %s", self.region, error)
send_event(EventProxyFailed(region=self.region, error=error)) send_event(
EventProxyFailed(region=self.region, proxy_id=self.proxy_id, error=error)
)
self.error = error self.error = error
self.state = VmState.stopping self.state = VmState.stopping
self.save() self.save()
@ -131,7 +140,7 @@ class Proxy(ORMMixin):
return return
self.ip = ip self.ip = ip
extensions = proxy_manager_extensions(self.region) extensions = proxy_manager_extensions(self.region, self.proxy_id)
result = vm.add_extensions(extensions) result = vm.add_extensions(extensions)
if isinstance(result, Error): if isinstance(result, Error):
self.set_failed(result) self.set_failed(result)
@ -154,6 +163,29 @@ class Proxy(ORMMixin):
logging.info(PROXY_LOG_PREFIX + "removing proxy: %s", self.region) logging.info(PROXY_LOG_PREFIX + "removing proxy: %s", self.region)
self.delete() self.delete()
def is_outdated(self) -> bool:
    """Report whether this proxy should be retired and replaced.

    A proxy is considered outdated when either:

    * its recorded ``version`` differs from the running service's
      ``__version__``, or
    * its ``created_timestamp`` is further in the past than
      ``PROXY_LIFESPAN``.

    Returns:
        ``True`` if the proxy is outdated, ``False`` otherwise.  The method
        only logs and reports; it does not modify or save the proxy.
    """
    if self.version != __version__:
        logging.info(
            PROXY_LOG_PREFIX + "mismatch version: proxy:%s service:%s state:%s",
            self.version,
            __version__,
            self.state,
        )
        return True

    created = self.created_timestamp
    if created is not None:
        # ``created_timestamp`` defaults to ``datetime.utcnow()``, which is a
        # naive datetime.  Ordering a naive datetime against an aware one
        # raises TypeError, so interpret a naive value as UTC before
        # comparing it with the timezone-aware cutoff below.
        if created.tzinfo is None:
            created = created.replace(tzinfo=datetime.timezone.utc)

        cutoff = datetime.datetime.now(tz=datetime.timezone.utc) - PROXY_LIFESPAN
        if created < cutoff:
            logging.info(
                PROXY_LOG_PREFIX
                + "proxy older than 7 days:proxy-created:%s state:%s",
                self.created_timestamp,
                self.state,
            )
            return True

    return False
def is_used(self) -> bool: def is_used(self) -> bool:
if len(self.get_forwards()) == 0: if len(self.get_forwards()) == 0:
logging.info(PROXY_LOG_PREFIX + "no forwards: %s", self.region) logging.info(PROXY_LOG_PREFIX + "no forwards: %s", self.region)
@ -194,7 +226,9 @@ class Proxy(ORMMixin):
def get_forwards(self) -> List[Forward]: def get_forwards(self) -> List[Forward]:
forwards: List[Forward] = [] forwards: List[Forward] = []
for entry in ProxyForward.search_forward(region=self.region): for entry in ProxyForward.search_forward(
region=self.region, proxy_id=self.proxy_id
):
if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc): if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
entry.delete() entry.delete()
else: else:
@ -212,7 +246,7 @@ class Proxy(ORMMixin):
proxy_config = ProxyConfig( proxy_config = ProxyConfig(
url=get_file_sas_url( url=get_file_sas_url(
Container("proxy-configs"), Container("proxy-configs"),
"%s/config.json" % self.region, "%s/%s/config.json" % (self.region, self.proxy_id),
StorageType.config, StorageType.config,
read=True, read=True,
), ),
@ -223,6 +257,7 @@ class Proxy(ORMMixin):
), ),
forwards=forwards, forwards=forwards,
region=self.region, region=self.region,
proxy_id=self.proxy_id,
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"), instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
instance_id=get_instance_id(), instance_id=get_instance_id(),
@ -230,7 +265,7 @@ class Proxy(ORMMixin):
save_blob( save_blob(
Container("proxy-configs"), Container("proxy-configs"),
"%s/config.json" % self.region, "%s/%s/config.json" % (self.region, self.proxy_id),
proxy_config.json(), proxy_config.json(),
StorageType.config, StorageType.config,
) )
@ -244,28 +279,22 @@ class Proxy(ORMMixin):
@classmethod @classmethod
def get_or_create(cls, region: Region) -> Optional["Proxy"]: def get_or_create(cls, region: Region) -> Optional["Proxy"]:
proxy = Proxy.get(region) proxy_list = Proxy.search(query={"region": [region], "outdated": [False]})
if proxy is not None: for proxy in proxy_list:
if proxy.version != __version__: if proxy.is_outdated():
logging.info( proxy.outdated = True
PROXY_LOG_PREFIX + "mismatch version: proxy:%s service:%s state:%s",
proxy.version,
__version__,
proxy.state,
)
if proxy.state != VmState.stopping:
# If the proxy is out-of-date, delete and re-create it
proxy.state = VmState.stopping
proxy.save() proxy.save()
return None continue
if proxy.state not in VmState.available():
continue
return proxy return proxy
logging.info(PROXY_LOG_PREFIX + "creating proxy: region:%s", region) logging.info(PROXY_LOG_PREFIX + "creating proxy: region:%s", region)
proxy = Proxy(region=region) proxy = Proxy(region=region)
proxy.save() proxy.save()
send_event(EventProxyCreated(region=region)) send_event(EventProxyCreated(region=region, proxy_id=proxy.proxy_id))
return proxy return proxy
def delete(self) -> None: def delete(self) -> None:
super().delete() super().delete()
send_event(EventProxyDeleted(region=self.region)) send_event(EventProxyDeleted(region=self.region, proxy_id=self.proxy_id))

View File

@ -26,6 +26,7 @@ class ProxyForward(ORMMixin):
port: int port: int
scaleset_id: UUID scaleset_id: UUID
machine_id: UUID machine_id: UUID
proxy_id: Optional[UUID]
dst_ip: str dst_ip: str
dst_port: int dst_port: int
endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
@ -93,11 +94,15 @@ class ProxyForward(ORMMixin):
cls, cls,
scaleset_id: UUID, scaleset_id: UUID,
*, *,
proxy_id: Optional[UUID] = None,
machine_id: Optional[UUID] = None, machine_id: Optional[UUID] = None,
dst_port: Optional[int] = None, dst_port: Optional[int] = None,
) -> List[Region]: ) -> List[Region]:
entries = cls.search_forward( entries = cls.search_forward(
scaleset_id=scaleset_id, machine_id=machine_id, dst_port=dst_port scaleset_id=scaleset_id,
machine_id=machine_id,
proxy_id=proxy_id,
dst_port=dst_port,
) )
regions = set() regions = set()
for entry in entries: for entry in entries:
@ -112,6 +117,7 @@ class ProxyForward(ORMMixin):
scaleset_id: Optional[UUID] = None, scaleset_id: Optional[UUID] = None,
region: Optional[Region] = None, region: Optional[Region] = None,
machine_id: Optional[UUID] = None, machine_id: Optional[UUID] = None,
proxy_id: Optional[UUID] = None,
dst_port: Optional[int] = None, dst_port: Optional[int] = None,
) -> List["ProxyForward"]: ) -> List["ProxyForward"]:
@ -125,6 +131,9 @@ class ProxyForward(ORMMixin):
if machine_id is not None: if machine_id is not None:
query["machine_id"] = [machine_id] query["machine_id"] = [machine_id]
if proxy_id is not None:
query["proxy_id"] = [proxy_id]
if dst_port is not None: if dst_port is not None:
query["dst_port"] = [dst_port] query["dst_port"] = [dst_port]

View File

@ -25,7 +25,6 @@ from ..azure.queue import create_queue, delete_queue
from ..azure.storage import StorageType from ..azure.storage import StorageType
from ..events import send_event from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..proxy_forward import ProxyForward
from ..workers.nodes import Node, NodeTasks from ..workers.nodes import Node, NodeTasks
from ..workers.pools import Pool from ..workers.pools import Pool
from ..workers.scalesets import Scaleset from ..workers.scalesets import Scaleset
@ -125,7 +124,6 @@ class Task(BASE_TASK, ORMMixin):
def stopping(self) -> None: def stopping(self) -> None:
logging.info("stopping task: %s:%s", self.job_id, self.task_id) logging.info("stopping task: %s:%s", self.job_id, self.task_id)
ProxyForward.remove_forward(self.task_id)
Node.stop_task(self.task_id) Node.stop_task(self.task_id)
if not NodeTasks.get_nodes_by_task_id(self.task_id): if not NodeTasks.get_nodes_by_task_id(self.task_id):
self.stopped() self.stopped()

View File

@ -79,6 +79,8 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
proxy = Proxy.get_or_create(scaleset.region) proxy = Proxy.get_or_create(scaleset.region)
if proxy: if proxy:
forward.proxy_id = proxy.proxy_id
forward.save()
proxy.save_proxy_config() proxy.save_proxy_config()
return ok(get_result(forward, proxy)) return ok(get_result(forward, proxy))

View File

@ -18,7 +18,7 @@ def main(msg: func.QueueMessage, dashboard: func.Out[str]) -> None:
logging.info(PROXY_LOG_PREFIX + "heartbeat: %s", body) logging.info(PROXY_LOG_PREFIX + "heartbeat: %s", body)
raw = json.loads(body) raw = json.loads(body)
heartbeat = ProxyHeartbeat.parse_obj(raw) heartbeat = ProxyHeartbeat.parse_obj(raw)
proxy = Proxy.get(heartbeat.region) proxy = Proxy.get(heartbeat.region, heartbeat.proxy_id)
if proxy is None: if proxy is None:
logging.warning( logging.warning(
PROXY_LOG_PREFIX + "received heartbeat for missing proxy: %s", body PROXY_LOG_PREFIX + "received heartbeat for missing proxy: %s", body

View File

@ -7,17 +7,40 @@ import logging
import azure.functions as func import azure.functions as func
from onefuzztypes.enums import VmState from onefuzztypes.enums import VmState
from onefuzztypes.events import EventProxyCreated
from ..onefuzzlib.events import get_events from ..onefuzzlib.events import get_events, send_event
from ..onefuzzlib.proxy import Proxy from ..onefuzzlib.proxy import Proxy
from ..onefuzzlib.webhooks import WebhookMessageLog from ..onefuzzlib.webhooks import WebhookMessageLog
from ..onefuzzlib.workers.scalesets import Scaleset from ..onefuzzlib.workers.scalesets import Scaleset
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841 def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
for proxy in Proxy.search(): proxy_list = Proxy.search()
# Marking Outdated Proxies. Subsequently, shutting down Outdated & Unused Proxies.
for proxy in proxy_list:
if proxy.is_outdated():
logging.info("marking proxy in %s as outdated.", proxy.region)
proxy.outdated = True
proxy.save()
# Creating a new proxy if no proxy exists for a given region.
for proxy in proxy_list:
if proxy.outdated:
region_list = list(
filter(
lambda x: (x.region == proxy.region and not x.outdated),
proxy_list,
)
)
if not len(region_list):
logging.info("outdated proxy in %s, creating new one.", proxy.region)
new_proxy = Proxy(region=proxy.region)
new_proxy.save()
send_event(
EventProxyCreated(region=proxy.region, proxy_id=proxy.proxy_id)
)
if not proxy.is_used(): if not proxy.is_used():
logging.info("stopping proxy") logging.info("stopping one proxy in %s.", proxy.region)
proxy.state = VmState.stopping proxy.state = VmState.stopping
proxy.save() proxy.save()

View File

@ -0,0 +1,17 @@
#!/bin/bash
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Runs the formatting, lint, security, import-order, type and unit-test
# checks for the API service.
# NOTE(review): uses the relative path src/api-service, so this presumably
# has to be invoked from the repository root - confirm against the caller.

# -e: abort on the first failing command; -x: echo each command as it runs.
set -ex
cd src/api-service
# Install the development tooling used by the checks below.
pip install -r requirements-dev.txt
# Formatting check only (--check reports problems instead of rewriting files).
black ./__app__ --check
# Style / lint check.
flake8 ./__app__
# Security scan, recursing into the package (-r).
bandit -r ./__app__
# Import-order check using the black-compatible profile; --check does not rewrite.
isort --profile black ./__app__ --check
# Static type check; imports without type information are not reported.
mypy ./__app__ --ignore-missing-imports
# Unit tests, verbose output.
pytest -v tests

View File

@ -48,6 +48,7 @@ pub struct ConfigData {
pub instance_telemetry_key: Option<InstanceTelemetryKey>, pub instance_telemetry_key: Option<InstanceTelemetryKey>,
pub microsoft_telemetry_key: Option<MicrosoftTelemetryKey>, pub microsoft_telemetry_key: Option<MicrosoftTelemetryKey>,
pub region: String, pub region: String,
pub proxy_id: Uuid,
pub url: Url, pub url: Url,
pub notification: Url, pub notification: Url,
pub forwards: Vec<Forward>, pub forwards: Vec<Forward>,
@ -56,6 +57,7 @@ pub struct ConfigData {
#[derive(Debug, Deserialize, Serialize, PartialEq)] #[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct NotifyResponse<'a> { pub struct NotifyResponse<'a> {
pub region: &'a str, pub region: &'a str,
pub proxy_id: Uuid,
pub forwards: Vec<Forward>, pub forwards: Vec<Forward>,
} }
@ -141,6 +143,7 @@ impl Config {
client client
.enqueue(NotifyResponse { .enqueue(NotifyResponse {
region: &self.data.region, region: &self.data.region,
proxy_id: self.data.proxy_id,
forwards: self.data.forwards.clone(), forwards: self.data.forwards.clone(),
}) })
.await?; .await?;

6
src/pytypes/extra/generate-docs.py Normal file → Executable file
View File

@ -156,10 +156,11 @@ def main() -> None:
state=TaskState.init, state=TaskState.init,
config=task_config, config=task_config,
), ),
EventProxyCreated(region=Region("eastus")), EventProxyCreated(region=Region("eastus"), proxy_id=UUID(int=0)),
EventProxyDeleted(region=Region("eastus")), EventProxyDeleted(region=Region("eastus"), proxy_id=UUID(int=0)),
EventProxyFailed( EventProxyFailed(
region=Region("eastus"), region=Region("eastus"),
proxy_id=UUID(int=0),
error=Error(code=ErrorCode.PROXY_FAILED, errors=["example error message"]), error=Error(code=ErrorCode.PROXY_FAILED, errors=["example error message"]),
), ),
EventPoolCreated( EventPoolCreated(
@ -272,7 +273,6 @@ def main() -> None:
) )
result = "" result = ""
result += layer( result += layer(
1, 1,
"Webhook Events", "Webhook Events",

View File

@ -121,14 +121,17 @@ class EventPoolCreated(BaseEvent):
class EventProxyCreated(BaseEvent): class EventProxyCreated(BaseEvent):
region: Region region: Region
proxy_id: Optional[UUID]
class EventProxyDeleted(BaseEvent): class EventProxyDeleted(BaseEvent):
region: Region region: Region
proxy_id: Optional[UUID]
class EventProxyFailed(BaseEvent): class EventProxyFailed(BaseEvent):
region: Region region: Region
proxy_id: Optional[UUID]
error: Error error: Error

View File

@ -432,6 +432,7 @@ class ProxyConfig(BaseModel):
url: str url: str
notification: str notification: str
region: Region region: Region
proxy_id: UUID
forwards: List[Forward] forwards: List[Forward]
instance_telemetry_key: Optional[str] instance_telemetry_key: Optional[str]
microsoft_telemetry_key: Optional[str] microsoft_telemetry_key: Optional[str]
@ -440,6 +441,7 @@ class ProxyConfig(BaseModel):
class ProxyHeartbeat(BaseModel): class ProxyHeartbeat(BaseModel):
region: Region region: Region
proxy_id: UUID
forwards: List[Forward] forwards: List[Forward]
timestamp: datetime = Field(default_factory=datetime.utcnow) timestamp: datetime = Field(default_factory=datetime.utcnow)