mirror of
https://github.com/projecthorus/sondehub-infra.git
synced 2025-01-29 15:13:53 +00:00
Add MQTT predictions for amateur payloads + amateur predictor tests (#124)
* Add MQTT predictions for amateur payloads * ignore env var changes amateur predictor --------- Co-authored-by: xss <michaela@michaela.lgbt>
This commit is contained in:
parent
3e849d2606
commit
8a80aa9c2f
@ -81,6 +81,9 @@ resource "aws_lambda_function" "ham_predict_updater" {
|
||||
tags = {
|
||||
Name = "ham_predict_updater"
|
||||
}
|
||||
lifecycle {
|
||||
ignore_changes = [environment]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,3 +1,7 @@
|
||||
import sys
|
||||
sys.path.append("sns_to_mqtt/vendor")
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@ -8,7 +12,11 @@ from math import radians, degrees, sin, cos, atan2, sqrt, pi
|
||||
import es
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
TAWHIRI_SERVER = "tawhiri.v2.sondehub.org"
|
||||
|
||||
# FLIGHT PROFILE DEFAULTS
|
||||
#
|
||||
@ -39,6 +47,49 @@ ALTITUDE_AGL_THRESHOLD = 150.0
|
||||
# Do not run predictions if the payload is below this altitude AMSL, and has an ascent rate below the above threshold.
|
||||
ALTITUDE_AMSL_THRESHOLD = 1500.0
|
||||
|
||||
|
||||
# Setup MQTT
|
||||
client = mqtt.Client(transport="websockets")
|
||||
|
||||
connected_flag = False
|
||||
|
||||
import socket
|
||||
socket.setdefaulttimeout(1)
|
||||
|
||||
|
||||
## MQTT functions
|
||||
def connect():
|
||||
client.on_connect = on_connect
|
||||
client.on_disconnect = on_disconnect
|
||||
client.on_publish = on_publish
|
||||
#client.tls_set()
|
||||
client.username_pw_set(username=os.getenv("MQTT_USERNAME"), password=os.getenv("MQTT_PASSWORD"))
|
||||
HOSTS = os.getenv("MQTT_HOST").split(",")
|
||||
PORT = int(os.getenv("MQTT_PORT", default="8080"))
|
||||
if PORT == 443:
|
||||
client.tls_set()
|
||||
HOST = random.choice(HOSTS)
|
||||
print(f"Connecting to {HOST}")
|
||||
client.connect(HOST, PORT, 5)
|
||||
client.loop_start()
|
||||
print("loop started")
|
||||
|
||||
def on_disconnect(client, userdata, rc):
|
||||
global connected_flag
|
||||
print("disconnected")
|
||||
connected_flag=False #set flag
|
||||
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
global connected_flag
|
||||
if rc==0:
|
||||
print("connected")
|
||||
connected_flag=True #set flag
|
||||
else:
|
||||
print("Bad connection Returned code")
|
||||
|
||||
def on_publish(client, userdata, mid):
|
||||
pass
|
||||
|
||||
def get_flight_docs():
|
||||
path = "flight-doc/_search"
|
||||
payload = {
|
||||
@ -252,7 +303,7 @@ def get_float_prediction(timestamp, latitude, longitude, altitude, current_rate=
|
||||
# Generate the prediction URL
|
||||
url = f"/api/v1/?launch_altitude={altitude}&launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&float_altitude={burst_altitude:.2f}&stop_datetime={(datetime.now() + timedelta(days=1)).isoformat()}Z&ascent_rate={ascent_rate:.2f}&profile=float_profile"
|
||||
logging.debug(url)
|
||||
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
|
||||
conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
|
||||
conn.request("GET", url)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
@ -309,7 +360,7 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
|
||||
# Generate the prediction URL
|
||||
url = f"/api/v1/?launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&launch_altitude={altitude:.2f}&ascent_rate={ascent_rate:.2f}&burst_altitude={burst_altitude:.2f}&descent_rate={descent_rate:.2f}"
|
||||
logging.debug(url)
|
||||
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
|
||||
conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
|
||||
conn.request("GET", url)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
@ -343,37 +394,38 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
|
||||
return None
|
||||
|
||||
|
||||
# Need to mock this out if we ever use it again
|
||||
#
|
||||
# def get_ruaumoko(latitude, longitude):
|
||||
# """
|
||||
# Request the ground level from ruaumoko.
|
||||
|
||||
def get_ruaumoko(latitude, longitude):
|
||||
"""
|
||||
Request the ground level from ruaumoko.
|
||||
# Returns 0.0 if the ground level could not be determined, effectively
|
||||
# defaulting to any checks based on this data being based on mean sea level.
|
||||
# """
|
||||
|
||||
Returns 0.0 if the ground level could not be determined, effectively
|
||||
defaulting to any checks based on this data being based on mean sea level.
|
||||
"""
|
||||
# # Shift longitude into the appropriate range for Tawhiri
|
||||
# if longitude < 0:
|
||||
# longitude += 360.0
|
||||
|
||||
# Shift longitude into the appropriate range for Tawhiri
|
||||
if longitude < 0:
|
||||
longitude += 360.0
|
||||
# # Generate the prediction URL
|
||||
# url = f"/api/ruaumoko/?latitude={latitude}&longitude={longitude}"
|
||||
# logging.debug(url)
|
||||
# conn = http.client.HTTPSConnection(TAWHIRI_SERVER)
|
||||
# conn.request("GET", url)
|
||||
# res = conn.getresponse()
|
||||
# data = res.read()
|
||||
|
||||
# Generate the prediction URL
|
||||
url = f"/api/ruaumoko/?latitude={latitude}&longitude={longitude}"
|
||||
logging.debug(url)
|
||||
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
|
||||
conn.request("GET", url)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
|
||||
if res.code != 200:
|
||||
logging.debug(data)
|
||||
return None
|
||||
# if res.code != 200:
|
||||
# logging.debug(data)
|
||||
# return None
|
||||
|
||||
resp_data = json.loads(data.decode("utf-8"))
|
||||
# resp_data = json.loads(data.decode("utf-8"))
|
||||
|
||||
if 'altitude' in resp_data:
|
||||
return resp_data['altitude']
|
||||
else:
|
||||
return 0.0
|
||||
# if 'altitude' in resp_data:
|
||||
# return resp_data['altitude']
|
||||
# else:
|
||||
# return 0.0
|
||||
|
||||
|
||||
def bulk_upload_es(index_prefix,payloads):
|
||||
@ -392,8 +444,11 @@ def bulk_upload_es(index_prefix,payloads):
|
||||
raise RuntimeError
|
||||
|
||||
def predict(event, context):
|
||||
# Connect to MQTT
|
||||
connect()
|
||||
# Use asyncio.run to synchronously "await" an async function
|
||||
result = asyncio.run(predict_async(event, context))
|
||||
time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with
|
||||
return result
|
||||
|
||||
async def predict_async(event, context):
|
||||
@ -593,6 +648,18 @@ async def predict_async(event, context):
|
||||
if len(output) > 0:
|
||||
bulk_upload_es("ham-predictions", output)
|
||||
|
||||
# upload to mqtt
|
||||
while not connected_flag:
|
||||
time.sleep(0.01) # wait until connected
|
||||
for prediction in output:
|
||||
logging.debug(f'Publishing prediction for {prediction["payload_callsign"]} to MQTT')
|
||||
client.publish(
|
||||
topic=f'amateur-prediction/{prediction["payload_callsign"]}',
|
||||
payload=json.dumps(prediction),
|
||||
qos=0,
|
||||
retain=False
|
||||
)
|
||||
logging.debug(f'Published prediction for {prediction["payload_callsign"]} to MQTT')
|
||||
|
||||
logging.debug("Finished")
|
||||
return
|
||||
@ -639,6 +706,7 @@ async def run_predictions_for_serial(sem, flight_docs, serial, value):
|
||||
if (abs(value['rate']) <= ASCENT_RATE_THRESHOLD) and (value['alt'] < ALTITUDE_AMSL_THRESHOLD):
|
||||
# Payload is 'floating' (e.g. is probably on the ground), and is below 1500m AMSL.
|
||||
# Don't run a prediction in this case. It probably just hasn't been launched yet.
|
||||
logging.debug(f"{serial} is floating and alt is low so not running prediction")
|
||||
return None
|
||||
|
||||
|
||||
|
@ -1,43 +1,66 @@
|
||||
from . import *
|
||||
from . import mock_values, test_values
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
# Predictor test
|
||||
# conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
|
||||
# _now = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# _ascent = get_standard_prediction(conn, _now, -34.0, 138.0, 10.0, burst_altitude=26000)
|
||||
# print(f"Got {len(_ascent)} data points for ascent prediction.")
|
||||
# _descent = get_standard_prediction(conn, _now, -34.0, 138.0, 24000.0, burst_altitude=24000.5)
|
||||
# print(f"Got {len(_descent)} data points for descent prediction.")
|
||||
|
||||
# test = predict(
|
||||
# {},{}
|
||||
# )
|
||||
#print(get_launch_sites())
|
||||
#print(get_reverse_predictions())
|
||||
# for _serial in test:
|
||||
# print(f"{_serial['serial']}: {len(_serial['data'])}")
|
||||
# Mock OpenSearch requests
|
||||
def mock_es_request(body, path, method):
|
||||
if path.endswith("_bulk"): # handle when the upload happens
|
||||
return {}
|
||||
elif(path == "flight-doc/_search"): # handle flightdoc queries
|
||||
return mock_values.flight_docs
|
||||
elif(path == "ham-telm-*/_search"): # handle telm searches
|
||||
return mock_values.ham_telm
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
# Mock out tawhiri
|
||||
class MockResponse(object):
|
||||
code = 200
|
||||
def read(self):
|
||||
return mock_values.tawhiri_respose # currently we only mock a float profile
|
||||
|
||||
class MockHTTPS(object):
|
||||
logging.debug(object)
|
||||
def __init__(self, url):
|
||||
logging.debug(url)
|
||||
def request(self,method, url):
|
||||
pass
|
||||
def getresponse(self):
|
||||
return MockResponse()
|
||||
|
||||
http.client.HTTPSConnection = MockHTTPS
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG
|
||||
)
|
||||
|
||||
print(predict(
|
||||
{},{}
|
||||
))
|
||||
# bulk_upload_es("reverse-prediction",[{
|
||||
# "datetime" : "2021-10-04",
|
||||
# "data" : { },
|
||||
# "serial" : "R12341234",
|
||||
# "station" : "-2",
|
||||
# "subtype" : "RS41-SGM",
|
||||
# "ascent_rate" : "5",
|
||||
# "alt" : 1000,
|
||||
# "position" : [
|
||||
# 1,
|
||||
# 2
|
||||
# ],
|
||||
# "type" : "RS41"
|
||||
# }]
|
||||
# )
|
||||
class TestAmateurPrediction(unittest.TestCase):
|
||||
def setUp(self):
|
||||
es.request = MagicMock(side_effect=mock_es_request)
|
||||
client.connect = MagicMock()
|
||||
client.loop_start = MagicMock()
|
||||
client.username_pw_set = MagicMock()
|
||||
client.tls_set = MagicMock()
|
||||
client.publish = MagicMock()
|
||||
on_connect(client, "userdata", "flags", 0)
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_float_prediction(self, MockSleep):
|
||||
predict({},{})
|
||||
date_prefix = datetime.now().strftime("%Y-%m")
|
||||
es.request.assert_has_calls(
|
||||
[
|
||||
call(json.dumps(test_values.flight_doc_search),"flight-doc/_search", "POST"),
|
||||
call(json.dumps(test_values.ham_telm_search), "ham-telm-*/_search", "GET"),
|
||||
call(test_values.es_bulk_upload,f"ham-predictions-{date_prefix}/_bulk","POST")
|
||||
]
|
||||
)
|
||||
client.username_pw_set.assert_called()
|
||||
client.loop_start.assert_called()
|
||||
client.connect.assert_called()
|
||||
client.publish.assert_has_calls([test_values.mqtt_publish_call])
|
||||
time.sleep.assert_called_with(0.5) # make sure we sleep to let paho mqtt queue clear
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
401
lambda/ham_predict_updater/mock_values.py
Normal file
401
lambda/ham_predict_updater/mock_values.py
Normal file
File diff suppressed because one or more lines are too long
36
lambda/ham_predict_updater/test_values.py
Normal file
36
lambda/ham_predict_updater/test_values.py
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user