Add prediction via websockets

This commit is contained in:
xss 2023-09-25 11:48:43 +10:00
parent 86836b61b0
commit b725cc5840
2 changed files with 132 additions and 43 deletions

View File

@ -1,3 +1,6 @@
import sys
sys.path.append("sns_to_mqtt/vendor")
import paho.mqtt.client as mqtt
import json
from datetime import datetime
import http.client
@ -7,7 +10,54 @@ from math import radians, degrees, sin, cos, atan2, sqrt, pi
import es
import asyncio
import functools
import os
import random
import time
import traceback
client = mqtt.Client(transport="websockets")
connected_flag = False
import socket
socket.setdefaulttimeout(1)
## MQTT functions
def connect():
client.on_connect = on_connect
client.on_disconnect = on_disconnect
client.on_publish = on_publish
#client.tls_set()
client.username_pw_set(username=os.getenv("MQTT_USERNAME"), password=os.getenv("MQTT_PASSWORD"))
HOSTS = os.getenv("MQTT_HOST").split(",")
PORT = int(os.getenv("MQTT_PORT", default="8080"))
if PORT == 443:
client.tls_set()
HOST = random.choice(HOSTS)
print(f"Connecting to {HOST}")
client.connect(HOST, PORT, 5)
client.loop_start()
print("loop started")
def on_disconnect(client, userdata, rc):
global connected_flag
print("disconnected")
connected_flag=False #set flag
def on_connect(client, userdata, flags, rc):
global connected_flag
if rc==0:
print("connected")
connected_flag=True #set flag
else:
print("Bad connection Returned code")
def on_publish(client, userdata, mid):
pass
# setup MQTT
connect()
# FLIGHT PROFILE DEFAULTS
#
@ -248,55 +298,58 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
- Longitude is in the range 0-360.0
- All ascent/descent rates must be positive.
"""
try:
# Bomb out if the rates are too low.
if ascent_rate < ASCENT_RATE_THRESHOLD:
return None
# Bomb out if the rates are too low.
if ascent_rate < ASCENT_RATE_THRESHOLD:
return None
if descent_rate < ASCENT_RATE_THRESHOLD:
return None
if descent_rate < ASCENT_RATE_THRESHOLD:
return None
# Shift longitude into the appropriate range for Tawhiri
if longitude < 0:
longitude += 360.0
# Shift longitude into the appropriate range for Tawhiri
if longitude < 0:
longitude += 360.0
# Generate the prediction URL
url = f"/api/v1/?launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&launch_altitude={altitude:.2f}&ascent_rate={ascent_rate:.2f}&burst_altitude={burst_altitude:.2f}&descent_rate={descent_rate:.2f}"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
# Generate the prediction URL
url = f"/api/v1/?launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&launch_altitude={altitude:.2f}&ascent_rate={ascent_rate:.2f}&burst_altitude={burst_altitude:.2f}&descent_rate={descent_rate:.2f}"
logging.debug(url)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn.request("GET", url)
res = conn.getresponse()
data = res.read()
if res.code != 200:
logging.debug(data)
return None
pred_data = json.loads(data.decode("utf-8"))
path = []
if 'prediction' in pred_data:
for stage in pred_data['prediction']:
# Probably don't need to worry about this, it should only result in one or two points
# in 'ascent'.
if stage['stage'] == 'ascent' and current_rate < 0: # ignore ascent stage if we have already burst
continue
else:
for item in stage['trajectory']:
path.append({
"time": int(datetime.fromisoformat(item['datetime'].split(".")[0].replace("Z","")).timestamp()),
"lat": item['latitude'],
"lon": item['longitude'] - 360 if item['longitude'] > 180 else item['longitude'],
"alt": item['altitude'],
})
if res.code != 200:
logging.debug(data)
return None
pred_data['path'] = path
return pred_data
else:
return None
pred_data = json.loads(data.decode("utf-8"))
path = []
if 'prediction' in pred_data:
for stage in pred_data['prediction']:
# Probably don't need to worry about this, it should only result in one or two points
# in 'ascent'.
if stage['stage'] == 'ascent' and current_rate < 0: # ignore ascent stage if we have already burst
continue
else:
for item in stage['trajectory']:
path.append({
"time": int(datetime.fromisoformat(item['datetime'].split(".")[0].replace("Z","")).timestamp()),
"lat": item['latitude'],
"lon": item['longitude'] - 360 if item['longitude'] > 180 else item['longitude'],
"alt": item['altitude'],
})
pred_data['path'] = path
return pred_data
else:
return None
except:
traceback.print_exc()
logging.error(f"Error turnning standard prediction for {url}")
return None
def get_launch_estimate(timestamp, latitude, longitude, altitude, ascent_rate=PREDICT_DEFAULTS['ascent_rate'], current_rate=5.0):
"""
@ -426,8 +479,10 @@ def bulk_upload_es(index_prefix,payloads):
raise RuntimeError
def predict(event, context):
client.loop(timeout=0.05, max_packets=1) # make sure MQTT reconnects
# Use asyncio.run to synchronously "await" an async function
result = asyncio.run(predict_async(event, context))
time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with
return result
async def predict_async(event, context):
@ -665,6 +720,27 @@ async def predict_async(event, context):
if len(output_reverse) > 0:
bulk_upload_es("reverse-prediction", output_reverse)
# upload to mqtt
while not connected_flag:
time.sleep(0.01) # wait until connected
for prediction in output:
logging.debug(f'Publishing prediction for {prediction["serial"]} to MQTT')
client.publish(
topic=f'prediction/{prediction["serial"]}',
payload=json.dumps(prediction),
qos=0,
retain=False
)
logging.debug(f'Published prediction for {prediction["serial"]} to MQTT')
for prediction in output_reverse:
logging.debug(f'Publishing reverse prediction for {prediction["serial"]} to MQTT')
client.publish(
topic=f'reverse-prediction/{prediction["serial"]}',
payload=json.dumps(prediction),
qos=0,
retain=False
)
logging.debug(f'Published reverse prediction for {prediction["serial"]} to MQTT')
logging.debug("Finished")
return
@ -819,4 +895,3 @@ async def run_predictions_for_serial(sem, serial, value, reverse_predictions, la
burst_altitude=burst_altitude,
descent_rate=descent_rate
)), reverse_serial_data]

View File

@ -53,6 +53,17 @@ resource "aws_iam_role_policy" "predict_updater" {
"Effect": "Allow",
"Action": "s3:*",
"Resource": "*"
},
{
"Action": [
"ec2:DescribeNetworkInterfaces",
"ec2:CreateNetworkInterface",
"ec2:DeleteNetworkInterface",
"ec2:DescribeInstances",
"ec2:AttachNetworkInterface"
],
"Effect": "Allow",
"Resource": "*"
}
]
}
@ -81,6 +92,9 @@ resource "aws_lambda_function" "predict_updater" {
tags = {
Name = "predict_updater"
}
lifecycle {
ignore_changes = [environment]
}
}