refactor tests - fix issues with connect running multiple times

This commit is contained in:
xss 2023-10-22 08:47:26 +11:00
parent 8a80aa9c2f
commit c69ac8c6e7
2 changed files with 35 additions and 16 deletions

View File

@ -52,6 +52,7 @@ ALTITUDE_AMSL_THRESHOLD = 1500.0
client = mqtt.Client(transport="websockets") client = mqtt.Client(transport="websockets")
connected_flag = False connected_flag = False
setup = False
import socket import socket
socket.setdefaulttimeout(1) socket.setdefaulttimeout(1)
@ -444,8 +445,11 @@ def bulk_upload_es(index_prefix,payloads):
raise RuntimeError raise RuntimeError
def predict(event, context): def predict(event, context):
global setup
# Connect to MQTT # Connect to MQTT
connect() if not setup:
connect()
setup = True
# Use asyncio.run to synchronously "await" an async function # Use asyncio.run to synchronously "await" an async function
result = asyncio.run(predict_async(event, context)) result = asyncio.run(predict_async(event, context))
time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with time.sleep(0.5) # give paho mqtt 500ms to send messages this could be improved on but paho mqtt is a pain to interface with

View File

@ -1,4 +1,8 @@
from . import * import ham_predict_updater
import logging
import json
from datetime import datetime
import time
from . import mock_values, test_values from . import mock_values, test_values
import unittest import unittest
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch
@ -29,7 +33,7 @@ class MockHTTPS(object):
def getresponse(self): def getresponse(self):
return MockResponse() return MockResponse()
http.client.HTTPSConnection = MockHTTPS ham_predict_updater.http.client.HTTPSConnection = MockHTTPS
logging.basicConfig( logging.basicConfig(
format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG
@ -37,30 +41,41 @@ logging.basicConfig(
class TestAmateurPrediction(unittest.TestCase): class TestAmateurPrediction(unittest.TestCase):
def setUp(self): def setUp(self):
es.request = MagicMock(side_effect=mock_es_request) ham_predict_updater.es.request = MagicMock(side_effect=mock_es_request)
client.connect = MagicMock() ham_predict_updater.client.connect = MagicMock()
client.loop_start = MagicMock() ham_predict_updater.client.loop_start = MagicMock()
client.username_pw_set = MagicMock() ham_predict_updater.client.username_pw_set = MagicMock()
client.tls_set = MagicMock() ham_predict_updater.client.tls_set = MagicMock()
client.publish = MagicMock() ham_predict_updater.client.publish = MagicMock()
on_connect(client, "userdata", "flags", 0) ham_predict_updater.on_connect(ham_predict_updater.client, "userdata", "flags", 0)
def tearDown(self): # reset some of the globals that get updated
ham_predict_updater.client = ham_predict_updater.mqtt.Client(transport="websockets")
ham_predict_updater.setup = False
ham_predict_updater.connected_flag = False
@patch("time.sleep") @patch("time.sleep")
def test_float_prediction(self, MockSleep): def test_float_prediction(self, MockSleep):
predict({},{}) ham_predict_updater.predict({},{})
date_prefix = datetime.now().strftime("%Y-%m") date_prefix = datetime.now().strftime("%Y-%m")
es.request.assert_has_calls( ham_predict_updater.es.request.assert_has_calls(
[ [
call(json.dumps(test_values.flight_doc_search),"flight-doc/_search", "POST"), call(json.dumps(test_values.flight_doc_search),"flight-doc/_search", "POST"),
call(json.dumps(test_values.ham_telm_search), "ham-telm-*/_search", "GET"), call(json.dumps(test_values.ham_telm_search), "ham-telm-*/_search", "GET"),
call(test_values.es_bulk_upload,f"ham-predictions-{date_prefix}/_bulk","POST") call(test_values.es_bulk_upload,f"ham-predictions-{date_prefix}/_bulk","POST")
] ]
) )
client.username_pw_set.assert_called() ham_predict_updater.client.username_pw_set.assert_called()
client.loop_start.assert_called() ham_predict_updater.client.loop_start.assert_called()
client.connect.assert_called() ham_predict_updater.client.connect.assert_called()
client.publish.assert_has_calls([test_values.mqtt_publish_call]) ham_predict_updater.client.publish.assert_has_calls([test_values.mqtt_publish_call])
time.sleep.assert_called_with(0.5) # make sure we sleep to let paho mqtt queue clear time.sleep.assert_called_with(0.5) # make sure we sleep to let paho mqtt queue clear
@patch('ham_predict_updater.connect')
def test_connect_only_called_once(self, mock_connect):
ham_predict_updater.predict({},{})
ham_predict_updater.predict({},{})
mock_connect.assert_called_once()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()