Add MQTT predictions for amateur payloads + amateur predictor tests (#124)

* Add MQTT predictions for amateur payloads

* ignore env var changes amateur predictor

---------

Co-authored-by: xss <michaela@michaela.lgbt>
This commit is contained in:
Michaela Wheeler 2023-10-22 07:44:19 +11:00 committed by GitHub
parent 3e849d2606
commit 8a80aa9c2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 592 additions and 61 deletions

View File

@ -81,6 +81,9 @@ resource "aws_lambda_function" "ham_predict_updater" {
tags = {
Name = "ham_predict_updater"
}
lifecycle {
ignore_changes = [environment]
}
}

View File

@ -1,3 +1,7 @@
import sys
sys.path.append("sns_to_mqtt/vendor")
import paho.mqtt.client as mqtt
import json
from datetime import datetime
from datetime import timedelta
@ -8,7 +12,11 @@ from math import radians, degrees, sin, cos, atan2, sqrt, pi
import es
import asyncio
import functools
import os
import random
import time
TAWHIRI_SERVER = "tawhiri.v2.sondehub.org"
# FLIGHT PROFILE DEFAULTS
#
@ -39,6 +47,49 @@ ALTITUDE_AGL_THRESHOLD = 150.0
# Do not run predictions if the payload is below this altitude AMSL, and has an ascent rate below the above threshold.
ALTITUDE_AMSL_THRESHOLD = 1500.0
# Setup MQTT
client = mqtt.Client(transport="websockets")
connected_flag = False
import socket
socket.setdefaulttimeout(1)
## MQTT functions
def connect():
client.on_connect = on_connect
client.on_disconnect = on_disconnect
client.on_publish = on_publish
#client.tls_set()
client.username_pw_set(username=os.getenv("MQTT_USERNAME"), password=os.getenv("MQTT_PASSWORD"))
HOSTS = os.getenv("MQTT_HOST").split(",")
PORT = int(os.getenv("MQTT_PORT", default="8080"))
if PORT == 443:
client.tls_set()
HOST = random.choice(HOSTS)
print(f"Connecting to {HOST}")
client.connect(HOST, PORT, 5)
client.loop_start()
print("loop started")
def on_disconnect(client, userdata, rc):
global connected_flag
print("disconnected")
connected_flag=False #set flag
def on_connect(client, userdata, flags, rc):
global connected_flag
if rc==0:
print("connected")
connected_flag=True #set flag
else:
print("Bad connection Returned code")
def on_publish(client, userdata, mid):
pass
def get_flight_docs():
path = "flight-doc/_search"
payload = {
@ -252,7 +303,7 @@ def get_float_prediction(timestamp, latitude, longitude, altitude, current_rate=
# Generate the prediction URL
url = f"/api/v1/?launch_altitude={altitude}&launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&float_altitude={burst_altitude:.2f}&stop_datetime={(datetime.now() + timedelta(days=1)).isoformat()}Z&ascent_rate={ascent_rate:.2f}&profile=float_profile"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
@ -309,7 +360,7 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
# Generate the prediction URL
url = f"/api/v1/?launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&launch_altitude={altitude:.2f}&ascent_rate={ascent_rate:.2f}&burst_altitude={burst_altitude:.2f}&descent_rate={descent_rate:.2f}"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
@ -343,37 +394,38 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
return None
# Need to mock this out if we ever use it again
#
# def get_ruaumoko(latitude, longitude):
# """
# Request the ground level from ruaumoko.
def get_ruaumoko(latitude, longitude):
"""
Request the ground level from ruaumoko.
# Returns 0.0 if the ground level could not be determined, effectively
# defaulting to any checks based on this data being based on mean sea level.
# """
Returns 0.0 if the ground level could not be determined, effectively
defaulting to any checks based on this data being based on mean sea level.
"""
# # Shift longitude into the appropriate range for Tawhiri
# if longitude < 0:
# longitude += 360.0
# Shift longitude into the appropriate range for Tawhiri
if longitude < 0:
longitude += 360.0
# # Generate the prediction URL
# url = f"/api/ruaumoko/?latitude={latitude}&longitude={longitude}"
# logging.debug(url)
# conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
# conn.request("GET", url)
# res = conn.getresponse()
# data = res.read()
# Generate the prediction URL
url = f"/api/ruaumoko/?latitude={latitude}&longitude={longitude}"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
if res.code != 200:
logging.debug(data)
return None
# if res.code != 200:
# logging.debug(data)
# return None
resp_data = json.loads(data.decode("utf-8"))
# resp_data = json.loads(data.decode("utf-8"))
if 'altitude' in resp_data:
return resp_data['altitude']
else:
return 0.0
# if 'altitude' in resp_data:
# return resp_data['altitude']
# else:
# return 0.0
def bulk_upload_es(index_prefix,payloads):
@ -392,8 +444,11 @@ def bulk_upload_es(index_prefix,payloads):
raise RuntimeError
def predict(event, context):
# Connect to MQTT
connect()
# Use asyncio.run to synchronously "await" an async function
result = asyncio.run(predict_async(event, context))
time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with
return result
async def predict_async(event, context):
@ -593,6 +648,18 @@ async def predict_async(event, context):
if len(output) > 0:
bulk_upload_es("ham-predictions", output)
# upload to mqtt
while not connected_flag:
time.sleep(0.01) # wait until connected
for prediction in output:
logging.debug(f'Publishing prediction for {prediction["payload_callsign"]} to MQTT')
client.publish(
topic=f'amateur-prediction/{prediction["payload_callsign"]}',
payload=json.dumps(prediction),
qos=0,
retain=False
)
logging.debug(f'Published prediction for {prediction["payload_callsign"]} to MQTT')
logging.debug("Finished")
return
@ -639,6 +706,7 @@ async def run_predictions_for_serial(sem, flight_docs, serial, value):
if (abs(value['rate']) <= ASCENT_RATE_THRESHOLD) and (value['alt'] < ALTITUDE_AMSL_THRESHOLD):
# Payload is 'floating' (e.g. is probably on the ground), and is below 1500m AMSL.
# Don't run a prediction in this case. It probably just hasn't been launched yet.
logging.debug(f"{serial} is floating and alt is low so not running prediction")
return None

View File

@ -1,43 +1,66 @@
from . import *
from . import mock_values, test_values
import unittest
from unittest.mock import MagicMock, call, patch
# Predictor test
# conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
# _now = datetime.utcnow().isoformat() + "Z"
# _ascent = get_standard_prediction(conn, _now, -34.0, 138.0, 10.0, burst_altitude=26000)
# print(f"Got {len(_ascent)} data points for ascent prediction.")
# _descent = get_standard_prediction(conn, _now, -34.0, 138.0, 24000.0, burst_altitude=24000.5)
# print(f"Got {len(_descent)} data points for descent prediction.")
# test = predict(
# {},{}
# )
#print(get_launch_sites())
#print(get_reverse_predictions())
# for _serial in test:
# print(f"{_serial['serial']}: {len(_serial['data'])}")
# Mock OpenSearch requests
def mock_es_request(body, path, method):
if path.endswith("_bulk"): # handle when the upload happens
return {}
elif(path == "flight-doc/_search"): # handle flightdoc queries
return mock_values.flight_docs
elif(path == "ham-telm-*/_search"): # handle telm searches
return mock_values.ham_telm
else:
raise NotImplemented
# Mock out tawhiri
class MockResponse(object):
code = 200
def read(self):
return mock_values.tawhiri_respose # currently we only mock a float profile
class MockHTTPS(object):
logging.debug(object)
def __init__(self, url):
logging.debug(url)
def request(self,method, url):
pass
def getresponse(self):
return MockResponse()
http.client.HTTPSConnection = MockHTTPS
logging.basicConfig(
format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG
)
print(predict(
{},{}
))
# bulk_upload_es("reverse-prediction",[{
# "datetime" : "2021-10-04",
# "data" : { },
# "serial" : "R12341234",
# "station" : "-2",
# "subtype" : "RS41-SGM",
# "ascent_rate" : "5",
# "alt" : 1000,
# "position" : [
# 1,
# 2
# ],
# "type" : "RS41"
# }]
# )
class TestAmateurPrediction(unittest.TestCase):
def setUp(self):
es.request = MagicMock(side_effect=mock_es_request)
client.connect = MagicMock()
client.loop_start = MagicMock()
client.username_pw_set = MagicMock()
client.tls_set = MagicMock()
client.publish = MagicMock()
on_connect(client, "userdata", "flags", 0)
@patch("time.sleep")
def test_float_prediction(self, MockSleep):
predict({},{})
date_prefix = datetime.now().strftime("%Y-%m")
es.request.assert_has_calls(
[
call(json.dumps(test_values.flight_doc_search),"flight-doc/_search", "POST"),
call(json.dumps(test_values.ham_telm_search), "ham-telm-*/_search", "GET"),
call(test_values.es_bulk_upload,f"ham-predictions-{date_prefix}/_bulk","POST")
]
)
client.username_pw_set.assert_called()
client.loop_start.assert_called()
client.connect.assert_called()
client.publish.assert_has_calls([test_values.mqtt_publish_call])
time.sleep.assert_called_with(0.5) # make sure we sleep to let paho mqtt queue clear
if __name__ == '__main__':
unittest.main()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long