mirror of https://github.com/projecthorus/sondehub-infra.git

Add prediction via websockets

parent 86836b61b0
commit b725cc5840
@@ -1,3 +1,6 @@
import sys
sys.path.append("sns_to_mqtt/vendor")
import paho.mqtt.client as mqtt
import json
from datetime import datetime
import http.client
@@ -7,7 +10,54 @@ from math import radians, degrees, sin, cos, atan2, sqrt, pi
import es
import asyncio
import functools
import os
import random
import time
import traceback

client = mqtt.Client(transport="websockets")

connected_flag = False

import socket
socket.setdefaulttimeout(1)


## MQTT functions
def connect():
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_publish = on_publish
    #client.tls_set()
    client.username_pw_set(username=os.getenv("MQTT_USERNAME"), password=os.getenv("MQTT_PASSWORD"))
    HOSTS = os.getenv("MQTT_HOST").split(",")
    PORT = int(os.getenv("MQTT_PORT", default="8080"))
    if PORT == 443:
        client.tls_set()
    HOST = random.choice(HOSTS)
    print(f"Connecting to {HOST}")
    client.connect(HOST, PORT, 5)
    client.loop_start()
    print("loop started")

def on_disconnect(client, userdata, rc):
    global connected_flag
    print("disconnected")
    connected_flag = False  # clear the flag so publishers wait for a reconnect

def on_connect(client, userdata, flags, rc):
    global connected_flag
    if rc == 0:
        print("connected")
        connected_flag = True  # set the connected flag
    else:
        print("Bad connection, returned code:", rc)

def on_publish(client, userdata, mid):
    pass

# setup MQTT
connect()

# FLIGHT PROFILE DEFAULTS
#
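For context (not part of the commit): connect() above runs at module import time and reads its broker settings from the environment. A minimal sketch of the configuration it expects, with placeholder hostnames and credentials:

# Placeholder broker settings (illustrative only); they must be present in the
# environment before this module is imported, because connect() runs at import time.
import os

os.environ["MQTT_HOST"] = "mqtt.example.org,mqtt2.example.org"   # comma-separated list; connect() picks one host at random
os.environ["MQTT_PORT"] = "443"                                  # port 443 causes connect() to enable TLS
os.environ["MQTT_USERNAME"] = "example-user"
os.environ["MQTT_PASSWORD"] = "example-pass"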
@@ -248,55 +298,58 @@ def get_standard_prediction(timestamp, latitude, longitude, altitude, current_ra
    - Longitude is in the range 0-360.0
    - All ascent/descent rates must be positive.
    """
    try:
        # Bomb out if the rates are too low.
        if ascent_rate < ASCENT_RATE_THRESHOLD:
            return None

        if descent_rate < ASCENT_RATE_THRESHOLD:
            return None

        # Shift longitude into the appropriate range for Tawhiri
        if longitude < 0:
            longitude += 360.0

        # Generate the prediction URL
        url = f"/api/v1/?launch_latitude={latitude}&launch_longitude={longitude}&launch_datetime={timestamp}&launch_altitude={altitude:.2f}&ascent_rate={ascent_rate:.2f}&burst_altitude={burst_altitude:.2f}&descent_rate={descent_rate:.2f}"
        logging.debug(url)
        conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
        conn.request("GET", url)
        res = conn.getresponse()
        data = res.read()

        if res.code != 200:
            logging.debug(data)
            return None

        pred_data = json.loads(data.decode("utf-8"))

        path = []

        if 'prediction' in pred_data:
            for stage in pred_data['prediction']:
                # Probably don't need to worry about this, it should only result in one or two points
                # in 'ascent'.
                if stage['stage'] == 'ascent' and current_rate < 0:  # ignore the ascent stage if we have already burst
                    continue
                else:
                    for item in stage['trajectory']:
                        path.append({
                            "time": int(datetime.fromisoformat(item['datetime'].split(".")[0].replace("Z","")).timestamp()),
                            "lat": item['latitude'],
                            "lon": item['longitude'] - 360 if item['longitude'] > 180 else item['longitude'],
                            "alt": item['altitude'],
                        })

            pred_data['path'] = path
            return pred_data
        else:
            return None
    except:
        traceback.print_exc()
        logging.error(f"Error running standard prediction for {url}")
        return None

def get_launch_estimate(timestamp, latitude, longitude, altitude, ascent_rate=PREDICT_DEFAULTS['ascent_rate'], current_rate=5.0):
    """
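For reference (not part of the commit), the Tawhiri call that get_standard_prediction() makes can be reproduced standalone. The launch parameters below are illustrative placeholders only:

import http.client
import json
from datetime import datetime, timezone

# Illustrative launch parameters; Tawhiri expects longitude in the 0-360 range and positive rates.
launch_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
params = (
    f"launch_latitude=-34.90&launch_longitude=138.60"
    f"&launch_datetime={launch_time}&launch_altitude=5000.00"
    f"&ascent_rate=5.00&burst_altitude=26000.00&descent_rate=6.00"
)
conn = http.client.HTTPSConnection("tawhiri.v2.sondehub.org")
conn.request("GET", f"/api/v1/?{params}")
res = conn.getresponse()
body = res.read()
prediction = json.loads(body.decode("utf-8")) if res.status == 200 else None
print(prediction["prediction"][0]["stage"] if prediction else "request failed")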
@@ -426,8 +479,10 @@ def bulk_upload_es(index_prefix,payloads):
        raise RuntimeError

def predict(event, context):
    client.loop(timeout=0.05, max_packets=1)  # make sure MQTT reconnects
    # Use asyncio.run to synchronously "await" an async function
    result = asyncio.run(predict_async(event, context))
    time.sleep(0.5)  # give paho-mqtt 500 ms to flush outgoing messages; this could be improved, but paho-mqtt is awkward to interface with
    return result

async def predict_async(event, context):
@@ -665,6 +720,27 @@ async def predict_async(event, context):
    if len(output_reverse) > 0:
        bulk_upload_es("reverse-prediction", output_reverse)

    # upload to mqtt
    while not connected_flag:
        time.sleep(0.01)  # wait until connected
    for prediction in output:
        logging.debug(f'Publishing prediction for {prediction["serial"]} to MQTT')
        client.publish(
            topic=f'prediction/{prediction["serial"]}',
            payload=json.dumps(prediction),
            qos=0,
            retain=False
        )
        logging.debug(f'Published prediction for {prediction["serial"]} to MQTT')
    for prediction in output_reverse:
        logging.debug(f'Publishing reverse prediction for {prediction["serial"]} to MQTT')
        client.publish(
            topic=f'reverse-prediction/{prediction["serial"]}',
            payload=json.dumps(prediction),
            qos=0,
            retain=False
        )
        logging.debug(f'Published reverse prediction for {prediction["serial"]} to MQTT')
    logging.debug("Finished")
    return
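For completeness (not part of the commit), a subscriber-side sketch of consuming the topics published above, using the same paho-mqtt 1.x websocket API as the vendored client; the broker hostname and port are placeholders:

import json
import paho.mqtt.client as mqtt

def on_message(client, userdata, msg):
    # Each payload is one JSON prediction document for a single sonde serial
    pred = json.loads(msg.payload)
    print(msg.topic, len(pred.get("path", [])), "path points")

sub = mqtt.Client(transport="websockets")   # paho-mqtt 1.x style, matching the code above
sub.tls_set()
sub.on_message = on_message
sub.connect("mqtt.example.org", 443, 60)    # placeholder endpoint; use the deployment's websocket broker
sub.subscribe("prediction/#")               # standard predictions, one topic per serial
sub.subscribe("reverse-prediction/#")       # launch-site (reverse) predictions
sub.loop_forever()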
@@ -819,4 +895,3 @@ async def run_predictions_for_serial(sem, serial, value, reverse_predictions, la
            burst_altitude=burst_altitude,
            descent_rate=descent_rate
        )), reverse_serial_data]
predictor.tf (14 lines changed)
@@ -53,6 +53,17 @@ resource "aws_iam_role_policy" "predict_updater" {
      "Effect": "Allow",
      "Action": "s3:*",
      "Resource": "*"
    },
    {
      "Action": [
        "ec2:DescribeNetworkInterfaces",
        "ec2:CreateNetworkInterface",
        "ec2:DeleteNetworkInterface",
        "ec2:DescribeInstances",
        "ec2:AttachNetworkInterface"
      ],
      "Effect": "Allow",
      "Resource": "*"
    }
  ]
}
@@ -81,6 +92,9 @@ resource "aws_lambda_function" "predict_updater" {
  tags = {
    Name = "predict_updater"
  }
  lifecycle {
    ignore_changes = [environment]
  }
}