diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index 3dc6bac96..80bd2661b 100644 --- a/src/allmydata/test/test_storage_http.py +++ b/src/allmydata/test/test_storage_http.py @@ -19,7 +19,7 @@ from base64 import b64encode from twisted.internet.defer import inlineCallbacks -from hypothesis import given, strategies as st +from hypothesis import assume, given, strategies as st from fixtures import Fixture, TempDir from treq.testing import StubTreq from klein import Klein @@ -37,6 +37,35 @@ from ..storage.http_server import ( from ..storage.http_client import StorageClient, ClientException +def _post_process(params): + secret_types, secrets = params + secrets = {t: s for (t, s) in zip(secret_types, secrets)} + headers = [ + "{} {}".format( + secret_type.value, str(b64encode(secrets[secret_type]), "ascii").strip() + ) + for secret_type in secret_types + ] + return secrets, headers + + +# Creates a tuple of ({Secret enum value: secret_bytes}, [http headers with secrets]). +SECRETS_STRATEGY = ( + st.sets(st.sampled_from(Secrets)) + .flatmap( + lambda secret_types: st.tuples( + st.just(secret_types), + st.lists( + st.binary(min_size=32, max_size=32), + min_size=len(secret_types), + max_size=len(secret_types), + ), + ) + ) + .map(_post_process) +) + + class ExtractSecretsTests(SyncTestCase): """ Tests for ``_extract_secrets``. @@ -47,54 +76,31 @@ class ExtractSecretsTests(SyncTestCase): raise SkipTest("Not going to bother supporting Python 2") super(ExtractSecretsTests, self).setUp() - @given( - params=st.sets(st.sampled_from(Secrets)).flatmap( - lambda secret_types: st.tuples( - st.just(secret_types), - st.lists( - st.binary(min_size=32, max_size=32), - min_size=len(secret_types), - max_size=len(secret_types), - ), - ) - ) - ) - def test_extract_secrets(self, params): + @given(secrets_to_send=SECRETS_STRATEGY) + def test_extract_secrets(self, secrets_to_send): """ ``_extract_secrets()`` returns a dictionary with the extracted secrets if the input secrets match the required secrets. """ - secret_types, secrets = params - secrets = {t: s for (t, s) in zip(secret_types, secrets)} - headers = [ - "{} {}".format( - secret_type.value, str(b64encode(secrets[secret_type]), "ascii").strip() - ) - for secret_type in secret_types - ] + secrets, headers = secrets_to_send # No secrets needed, none given: - self.assertEqual(_extract_secrets(headers, secret_types), secrets) + self.assertEqual(_extract_secrets(headers, secrets.keys()), secrets) - def test_wrong_number_of_secrets(self): + @given( + secrets_to_send=SECRETS_STRATEGY, + secrets_to_require=st.sets(st.sampled_from(Secrets)), + ) + def test_wrong_number_of_secrets(self, secrets_to_send, secrets_to_require): """ If the wrong number of secrets are passed to ``_extract_secrets``, a ``ClientSecretsException`` is raised. """ - secret1 = b"\xFF\x11ZEBRa" - lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip() + secrets_to_send, headers = secrets_to_send + assume(secrets_to_send.keys() != secrets_to_require) - # Missing secret: with self.assertRaises(ClientSecretsException): - _extract_secrets([], {Secrets.LEASE_RENEW}) - - # Wrong secret: - with self.assertRaises(ClientSecretsException): - _extract_secrets([lease_secret], {Secrets.UPLOAD}) - - # Extra secret: - with self.assertRaises(ClientSecretsException): - _extract_secrets([lease_secret], {}) + _extract_secrets(headers, secrets_to_require) def test_bad_secrets(self): """