From cb83b08923355d93b36c5f955fcf4dfab45fc500 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 16 Oct 2023 11:16:10 -0400 Subject: [PATCH] Decouple from reactor --- src/allmydata/client.py | 11 +++------ src/allmydata/codec.py | 4 +--- src/allmydata/mutable/publish.py | 2 +- src/allmydata/mutable/retrieve.py | 4 ++-- src/allmydata/storage/http_client.py | 6 ++--- src/allmydata/storage/http_server.py | 2 +- src/allmydata/test/mutable/util.py | 2 +- src/allmydata/test/no_network.py | 4 +--- src/allmydata/test/test_storage_http.py | 7 ++++++ src/allmydata/test/test_util.py | 22 +++++++++++++++-- src/allmydata/util/cputhreadpool.py | 32 +++++++++++++++++++++---- 11 files changed, 67 insertions(+), 29 deletions(-) diff --git a/src/allmydata/client.py b/src/allmydata/client.py index 78c8484da..9a1a75ebc 100644 --- a/src/allmydata/client.py +++ b/src/allmydata/client.py @@ -172,9 +172,6 @@ class KeyGenerator(object): """I create RSA keys for mutable files. Each call to generate() returns a single keypair.""" - def __init__(self, reactor: IReactorFromThreads): - self._reactor = reactor - def generate(self) -> defer.Deferred[tuple[rsa.PublicKey, rsa.PrivateKey]]: """ I return a Deferred that fires with a (verifyingkey, signingkey) @@ -182,7 +179,7 @@ class KeyGenerator(object): """ keysize = 2048 return defer_to_thread( - self._reactor, rsa.create_signing_keypair, keysize + rsa.create_signing_keypair, keysize ).addCallback(lambda t: (t[1], t[0])) @@ -631,13 +628,11 @@ class _Client(node.Node, pollmixin.PollMixin): } def __init__(self, config, main_tub, i2p_provider, tor_provider, introducer_clients, - storage_farm_broker, reactor=None): + storage_farm_broker): """ Use :func:`allmydata.client.create_client` to instantiate one of these. 
""" node.Node.__init__(self, config, main_tub, i2p_provider, tor_provider) - if reactor is None: - from twisted.internet import reactor self.started_timestamp = time.time() self.logSource = "Client" @@ -649,7 +644,7 @@ class _Client(node.Node, pollmixin.PollMixin): self.init_stats_provider() self.init_secrets() self.init_node_key() - self._key_generator = KeyGenerator(reactor) + self._key_generator = KeyGenerator() key_gen_furl = config.get_config("client", "key_generator.furl", None) if key_gen_furl: log.msg("[client]key_generator.furl= is now ignored, see #2783") diff --git a/src/allmydata/codec.py b/src/allmydata/codec.py index 51dc74a8a..402af9204 100644 --- a/src/allmydata/codec.py +++ b/src/allmydata/codec.py @@ -13,7 +13,6 @@ if PY2: from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 from zope.interface import implementer -from twisted.internet import defer, reactor from allmydata.util import mathutil from allmydata.util.assertutil import precondition from allmydata.util.cputhreadpool import defer_to_thread @@ -54,7 +53,7 @@ class CRSEncoder(object): for inshare in inshares: assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares) - d = defer_to_thread(reactor, self.encoder.encode, inshares, desired_share_ids) + d = defer_to_thread(self.encoder.encode, inshares, desired_share_ids) d.addCallback(lambda shares: (shares, desired_share_ids)) return d @@ -84,7 +83,6 @@ class CRSDecoder(object): precondition(len(some_shares) == self.required_shares, len(some_shares), self.required_shares) return defer_to_thread( - reactor, self.decoder.decode, some_shares, [int(s) for s in their_shareids] diff --git a/src/allmydata/mutable/publish.py b/src/allmydata/mutable/publish.py index e262ab967..bc87d5e0c 100644 --- a/src/allmydata/mutable/publish.py +++ b/src/allmydata/mutable/publish.py @@ -779,7 +779,7 @@ class 
Publish(object): hashed = salt + sharedata else: hashed = sharedata - block_hash = await defer_to_thread(reactor, hashutil.block_hash, hashed) + block_hash = await defer_to_thread(hashutil.block_hash, hashed) self.blockhashes[shareid][segnum] = block_hash # find the writer for this share writers = self.writers[shareid] diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py index 93d0a410f..22e846aa5 100644 --- a/src/allmydata/mutable/retrieve.py +++ b/src/allmydata/mutable/retrieve.py @@ -769,9 +769,9 @@ class Retrieve(object): "block hash tree failure: %s" % e) if self._version == MDMF_VERSION: - blockhash = await defer_to_thread(reactor, hashutil.block_hash, salt + block) + blockhash = await defer_to_thread(hashutil.block_hash, salt + block) else: - blockhash = await defer_to_thread(reactor, hashutil.block_hash, block) + blockhash = await defer_to_thread(hashutil.block_hash, block) # If this works without an error, then validation is # successful. try: diff --git a/src/allmydata/storage/http_client.py b/src/allmydata/storage/http_client.py index 41c97bfb6..f3aef8b88 100644 --- a/src/allmydata/storage/http_client.py +++ b/src/allmydata/storage/http_client.py @@ -541,9 +541,7 @@ class StorageClient(object): "Can't use both `message_to_serialize` and `data` " "as keyword arguments at the same time" ) - kwargs["data"] = await defer_to_thread( - self._clock, dumps, message_to_serialize - ) + kwargs["data"] = await defer_to_thread(dumps, message_to_serialize) headers.addRawHeader("Content-Type", CBOR_MIME_TYPE) response = await self._treq.request( @@ -566,7 +564,7 @@ class StorageClient(object): schema.validate_cbor(data) return loads(data) - return await defer_to_thread(self._clock, validate_and_decode) + return await defer_to_thread(validate_and_decode) else: raise ClientException( -1, diff --git a/src/allmydata/storage/http_server.py b/src/allmydata/storage/http_server.py index 7e6682207..5b4e02288 100644 --- 
a/src/allmydata/storage/http_server.py +++ b/src/allmydata/storage/http_server.py @@ -638,7 +638,7 @@ async def read_encoded( # Pycddl will release the GIL when validating larger documents, so # let's take advantage of multiple CPUs: - await defer_to_thread(reactor, schema.validate_cbor, message) + await defer_to_thread(schema.validate_cbor, message) # The CBOR parser will allocate more memory, but at least we can feed # it the file-like object, so that if it's large it won't be make two diff --git a/src/allmydata/test/mutable/util.py b/src/allmydata/test/mutable/util.py index ab2f64f36..bed350652 100644 --- a/src/allmydata/test/mutable/util.py +++ b/src/allmydata/test/mutable/util.py @@ -317,7 +317,7 @@ def make_nodemaker_with_storage_broker(storage_broker): :param StorageFarmBroker peers: The storage broker to use. """ sh = client.SecretHolder(b"lease secret", b"convergence secret") - keygen = client.KeyGenerator(reactor) + keygen = client.KeyGenerator() nodemaker = NodeMaker(storage_broker, sh, None, None, None, {"k": 3, "n": 10}, SDMF_VERSION, keygen) diff --git a/src/allmydata/test/no_network.py b/src/allmydata/test/no_network.py index 20e4057e2..dbf994ee0 100644 --- a/src/allmydata/test/no_network.py +++ b/src/allmydata/test/no_network.py @@ -246,15 +246,13 @@ def create_no_network_client(basedir): from allmydata.client import read_config config = read_config(basedir, u'client.port') storage_broker = NoNetworkStorageBroker() - from twisted.internet import reactor client = _NoNetworkClient( config, main_tub=None, i2p_provider=None, tor_provider=None, introducer_clients=[], - storage_farm_broker=storage_broker, - reactor=reactor, + storage_farm_broker=storage_broker ) # this is a (pre-existing) reference-cycle and also a bad idea, see: # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949 diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index eaed5f07c..b866f027a 100644 --- a/src/allmydata/test/test_storage_http.py 
+++ b/src/allmydata/test/test_storage_http.py @@ -43,6 +43,7 @@ from testtools.matchers import Equals from zope.interface import implementer from ..util.deferredutil import async_to_deferred +from ..util.cputhreadpool import disable_thread_pool_for_test from .common import SyncTestCase from ..storage.http_common import ( get_content_type, @@ -345,6 +346,7 @@ class CustomHTTPServerTests(SyncTestCase): def setUp(self): super(CustomHTTPServerTests, self).setUp() + disable_thread_pool_for_test(self) StorageClientFactory.start_test_mode( lambda pool: self.addCleanup(pool.closeCachedConnections) ) @@ -701,6 +703,7 @@ class GenericHTTPAPITests(SyncTestCase): def setUp(self): super(GenericHTTPAPITests, self).setUp() + disable_thread_pool_for_test(self) self.http = self.useFixture(HttpTestFixture()) def test_missing_authentication(self) -> None: @@ -808,6 +811,7 @@ class ImmutableHTTPAPITests(SyncTestCase): def setUp(self): super(ImmutableHTTPAPITests, self).setUp() + disable_thread_pool_for_test(self) self.http = self.useFixture(HttpTestFixture()) self.imm_client = StorageClientImmutables(self.http.client) self.general_client = StorageClientGeneral(self.http.client) @@ -1317,6 +1321,7 @@ class MutableHTTPAPIsTests(SyncTestCase): def setUp(self): super(MutableHTTPAPIsTests, self).setUp() + disable_thread_pool_for_test(self) self.http = self.useFixture(HttpTestFixture()) self.mut_client = StorageClientMutables(self.http.client) @@ -1734,6 +1739,7 @@ class ImmutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase): def setUp(self): super(ImmutableSharedTests, self).setUp() + disable_thread_pool_for_test(self) self.http = self.useFixture(HttpTestFixture()) self.client = self.clientFactory(self.http.client) self.general_client = StorageClientGeneral(self.http.client) @@ -1788,6 +1794,7 @@ class MutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase): def setUp(self): super(MutableSharedTests, self).setUp() + disable_thread_pool_for_test(self) self.http = 
self.useFixture(HttpTestFixture()) self.client = self.clientFactory(self.http.client) self.general_client = StorageClientGeneral(self.http.client) diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py index 111f817a8..2b3f33474 100644 --- a/src/allmydata/test/test_util.py +++ b/src/allmydata/test/test_util.py @@ -28,7 +28,7 @@ from allmydata.util import pollmixin from allmydata.util import yamlutil from allmydata.util import rrefutil from allmydata.util.fileutil import EncryptedTemporaryFile -from allmydata.util.cputhreadpool import defer_to_thread +from allmydata.util.cputhreadpool import defer_to_thread, disable_thread_pool_for_test from allmydata.test.common_util import ReallyEqualMixin from .no_network import fireNow, LocalWrapper @@ -613,7 +613,7 @@ class CPUThreadPool(unittest.TestCase): return current_thread(), args, kwargs this_thread = current_thread().ident - result = defer_to_thread(reactor, f, 1, 3, key=4, value=5) + result = defer_to_thread(f, 1, 3, key=4, value=5) # Callbacks run in the correct thread: callback_thread_ident = [] @@ -630,3 +630,21 @@ class CPUThreadPool(unittest.TestCase): self.assertEqual(args, (1, 3)) self.assertEqual(kwargs, {"key": 4, "value": 5}) + def test_when_disabled_runs_in_same_thread(self): + """ + If the CPU thread pool is disabled, the given function runs in the + current thread. 
+ """ + disable_thread_pool_for_test(self) + def f(*args, **kwargs): + return current_thread().ident, args, kwargs + + this_thread = current_thread().ident + result = defer_to_thread(f, 1, 3, key=4, value=5) + l = [] + result.addCallback(l.append) + thread, args, kwargs = l[0] + + self.assertEqual(thread, this_thread) + self.assertEqual(args, (1, 3)) + self.assertEqual(kwargs, {"key": 4, "value": 5}) diff --git a/src/allmydata/util/cputhreadpool.py b/src/allmydata/util/cputhreadpool.py index 225232e04..33799e150 100644 --- a/src/allmydata/util/cputhreadpool.py +++ b/src/allmydata/util/cputhreadpool.py @@ -19,12 +19,13 @@ from typing import TypeVar, Callable, cast from functools import partial import threading from typing_extensions import ParamSpec +from unittest import TestCase from twisted.python.threadpool import ThreadPool -from twisted.internet.defer import Deferred +from twisted.internet.defer import Deferred, maybeDeferred from twisted.internet.threads import deferToThreadPool from twisted.internet.interfaces import IReactorFromThreads - +from twisted.internet import reactor _CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU") if hasattr(threading, "_register_atexit"): @@ -46,14 +47,37 @@ _CPU_THREAD_POOL.start() P = ParamSpec("P") R = TypeVar("R") +# Is running in a thread pool disabled? Should only be true in synchronous unit +# tests. +_DISABLED = False + def defer_to_thread( - reactor: IReactorFromThreads, f: Callable[P, R], *args: P.args, **kwargs: P.kwargs + f: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Deferred[R]: """Run the function in a thread, return the result as a ``Deferred``.""" + if _DISABLED: + return maybeDeferred(f, *args, **kwargs) + # deferToThreadPool has no type annotations... 
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs) return cast(Deferred[R], result) -__all__ = ["defer_to_thread"] +def disable_thread_pool_for_test(test: TestCase) -> None: + """ + For the duration of the test, calls to ``defer_to_thread`` run the function + synchronously in the calling thread instead of in the CPU thread pool. + """ + global _DISABLED + + def restore(): + global _DISABLED + _DISABLED = False + + test.addCleanup(restore) + + _DISABLED = True + + +__all__ = ["defer_to_thread", "disable_thread_pool_for_test"]