diff --git a/lambda/ham_predict_updater/__init__.py b/lambda/ham_predict_updater/__init__.py index fbe4a18..93db9d3 100644 --- a/lambda/ham_predict_updater/__init__.py +++ b/lambda/ham_predict_updater/__init__.py @@ -52,6 +52,7 @@ ALTITUDE_AMSL_THRESHOLD = 1500.0 client = mqtt.Client(transport="websockets") connected_flag = False +setup = False import socket socket.setdefaulttimeout(1) @@ -444,8 +445,11 @@ def bulk_upload_es(index_prefix,payloads): raise RuntimeError def predict(event, context): + global setup # Connect to MQTT - connect() + if not setup: + connect() + setup = True # Use asyncio.run to synchronously "await" an async function result = asyncio.run(predict_async(event, context)) time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with diff --git a/lambda/ham_predict_updater/__main__.py b/lambda/ham_predict_updater/__main__.py index ebc82cf..e3643e4 100644 --- a/lambda/ham_predict_updater/__main__.py +++ b/lambda/ham_predict_updater/__main__.py @@ -1,4 +1,8 @@ -from . import * +import ham_predict_updater +import logging +import json +from datetime import datetime +import time from . import mock_values, test_values import unittest from unittest.mock import MagicMock, call, patch @@ -29,7 +33,7 @@ class MockHTTPS(object): def getresponse(self): return MockResponse() -http.client.HTTPSConnection = MockHTTPS +ham_predict_updater.http.client.HTTPSConnection = MockHTTPS logging.basicConfig( format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG @@ -37,30 +41,41 @@ logging.basicConfig( class TestAmateurPrediction(unittest.TestCase): def setUp(self): - es.request = MagicMock(side_effect=mock_es_request) - client.connect = MagicMock() - client.loop_start = MagicMock() - client.username_pw_set = MagicMock() - client.tls_set = MagicMock() - client.publish = MagicMock() - on_connect(client, "userdata", "flags", 0) + ham_predict_updater.es.request = MagicMock(side_effect=mock_es_request) + ham_predict_updater.client.connect = MagicMock() + ham_predict_updater.client.loop_start = MagicMock() + ham_predict_updater.client.username_pw_set = MagicMock() + ham_predict_updater.client.tls_set = MagicMock() + ham_predict_updater.client.publish = MagicMock() + ham_predict_updater.on_connect(ham_predict_updater.client, "userdata", "flags", 0) + def tearDown(self): # reset some of the globals that get updated + ham_predict_updater.client = ham_predict_updater.mqtt.Client(transport="websockets") + ham_predict_updater.setup = False + ham_predict_updater.connected_flag = False + @patch("time.sleep") def test_float_prediction(self, MockSleep): - predict({},{}) + ham_predict_updater.predict({},{}) date_prefix = datetime.now().strftime("%Y-%m") - es.request.assert_has_calls( + ham_predict_updater.es.request.assert_has_calls( [ call(json.dumps(test_values.flight_doc_search),"flight-doc/_search", "POST"), call(json.dumps(test_values.ham_telm_search), "ham-telm-*/_search", "GET"), call(test_values.es_bulk_upload,f"ham-predictions-{date_prefix}/_bulk","POST") ] ) - client.username_pw_set.assert_called() - client.loop_start.assert_called() - client.connect.assert_called() - client.publish.assert_has_calls([test_values.mqtt_publish_call]) + ham_predict_updater.client.username_pw_set.assert_called() + ham_predict_updater.client.loop_start.assert_called() + ham_predict_updater.client.connect.assert_called() + ham_predict_updater.client.publish.assert_has_calls([test_values.mqtt_publish_call]) time.sleep.assert_called_with(0.5) # make sure we sleep to let paho mqtt queue clear + + @patch('ham_predict_updater.connect') + def test_connect_only_called_once(self, mock_connect): + ham_predict_updater.predict({},{}) + ham_predict_updater.predict({},{}) + mock_connect.assert_called_once() if __name__ == '__main__': unittest.main()