Use hypothesis for another test.

This commit is contained in:
Itamar Turner-Trauring 2021-12-22 11:44:45 -05:00
parent 776f19cbb2
commit 8b4d166a54

View File

@ -19,7 +19,7 @@ from base64 import b64encode
from twisted.internet.defer import inlineCallbacks
from hypothesis import given, strategies as st
from hypothesis import assume, given, strategies as st
from fixtures import Fixture, TempDir
from treq.testing import StubTreq
from klein import Klein
@ -37,6 +37,35 @@ from ..storage.http_server import (
from ..storage.http_client import StorageClient, ClientException
def _post_process(params):
secret_types, secrets = params
secrets = {t: s for (t, s) in zip(secret_types, secrets)}
headers = [
"{} {}".format(
secret_type.value, str(b64encode(secrets[secret_type]), "ascii").strip()
)
for secret_type in secret_types
]
return secrets, headers
# Creates a tuple of ({Secret enum value: secret_bytes}, [http headers with secrets]).
SECRETS_STRATEGY = (
st.sets(st.sampled_from(Secrets))
.flatmap(
lambda secret_types: st.tuples(
st.just(secret_types),
st.lists(
st.binary(min_size=32, max_size=32),
min_size=len(secret_types),
max_size=len(secret_types),
),
)
)
.map(_post_process)
)
class ExtractSecretsTests(SyncTestCase):
"""
Tests for ``_extract_secrets``.
@ -47,54 +76,31 @@ class ExtractSecretsTests(SyncTestCase):
raise SkipTest("Not going to bother supporting Python 2")
super(ExtractSecretsTests, self).setUp()
@given(
params=st.sets(st.sampled_from(Secrets)).flatmap(
lambda secret_types: st.tuples(
st.just(secret_types),
st.lists(
st.binary(min_size=32, max_size=32),
min_size=len(secret_types),
max_size=len(secret_types),
),
)
)
)
def test_extract_secrets(self, params):
@given(secrets_to_send=SECRETS_STRATEGY)
def test_extract_secrets(self, secrets_to_send):
"""
``_extract_secrets()`` returns a dictionary with the extracted secrets
if the input secrets match the required secrets.
"""
secret_types, secrets = params
secrets = {t: s for (t, s) in zip(secret_types, secrets)}
headers = [
"{} {}".format(
secret_type.value, str(b64encode(secrets[secret_type]), "ascii").strip()
)
for secret_type in secret_types
]
secrets, headers = secrets_to_send
# No secrets needed, none given:
self.assertEqual(_extract_secrets(headers, secret_types), secrets)
self.assertEqual(_extract_secrets(headers, secrets.keys()), secrets)
def test_wrong_number_of_secrets(self):
@given(
secrets_to_send=SECRETS_STRATEGY,
secrets_to_require=st.sets(st.sampled_from(Secrets)),
)
def test_wrong_number_of_secrets(self, secrets_to_send, secrets_to_require):
"""
If the wrong number of secrets are passed to ``_extract_secrets``, a
``ClientSecretsException`` is raised.
"""
secret1 = b"\xFF\x11ZEBRa"
lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip()
secrets_to_send, headers = secrets_to_send
assume(secrets_to_send.keys() != secrets_to_require)
# Missing secret:
with self.assertRaises(ClientSecretsException):
_extract_secrets([], {Secrets.LEASE_RENEW})
# Wrong secret:
with self.assertRaises(ClientSecretsException):
_extract_secrets([lease_secret], {Secrets.UPLOAD})
# Extra secret:
with self.assertRaises(ClientSecretsException):
_extract_secrets([lease_secret], {})
_extract_secrets(headers, secrets_to_require)
def test_bad_secrets(self):
"""