Decouple from reactor

This commit is contained in:
Itamar Turner-Trauring 2023-10-16 11:16:10 -04:00
parent b60e53b3fb
commit cb83b08923
11 changed files with 67 additions and 29 deletions

View File

@ -172,9 +172,6 @@ class KeyGenerator(object):
"""I create RSA keys for mutable files. Each call to generate() returns a
single keypair."""
def __init__(self, reactor: IReactorFromThreads):
self._reactor = reactor
def generate(self) -> defer.Deferred[tuple[rsa.PublicKey, rsa.PrivateKey]]:
"""
I return a Deferred that fires with a (verifyingkey, signingkey)
@ -182,7 +179,7 @@ class KeyGenerator(object):
"""
keysize = 2048
return defer_to_thread(
self._reactor, rsa.create_signing_keypair, keysize
rsa.create_signing_keypair, keysize
).addCallback(lambda t: (t[1], t[0]))
@ -631,13 +628,11 @@ class _Client(node.Node, pollmixin.PollMixin):
}
def __init__(self, config, main_tub, i2p_provider, tor_provider, introducer_clients,
storage_farm_broker, reactor=None):
storage_farm_broker):
"""
Use :func:`allmydata.client.create_client` to instantiate one of these.
"""
node.Node.__init__(self, config, main_tub, i2p_provider, tor_provider)
if reactor is None:
from twisted.internet import reactor
self.started_timestamp = time.time()
self.logSource = "Client"
@ -649,7 +644,7 @@ class _Client(node.Node, pollmixin.PollMixin):
self.init_stats_provider()
self.init_secrets()
self.init_node_key()
self._key_generator = KeyGenerator(reactor)
self._key_generator = KeyGenerator()
key_gen_furl = config.get_config("client", "key_generator.furl", None)
if key_gen_furl:
log.msg("[client]key_generator.furl= is now ignored, see #2783")

View File

@ -13,7 +13,6 @@ if PY2:
from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from zope.interface import implementer
from twisted.internet import defer, reactor
from allmydata.util import mathutil
from allmydata.util.assertutil import precondition
from allmydata.util.cputhreadpool import defer_to_thread
@ -54,7 +53,7 @@ class CRSEncoder(object):
for inshare in inshares:
assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
d = defer_to_thread(reactor, self.encoder.encode, inshares, desired_share_ids)
d = defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
d.addCallback(lambda shares: (shares, desired_share_ids))
return d
@ -84,7 +83,6 @@ class CRSDecoder(object):
precondition(len(some_shares) == self.required_shares,
len(some_shares), self.required_shares)
return defer_to_thread(
reactor,
self.decoder.decode,
some_shares,
[int(s) for s in their_shareids]

View File

@ -779,7 +779,7 @@ class Publish(object):
hashed = salt + sharedata
else:
hashed = sharedata
block_hash = await defer_to_thread(reactor, hashutil.block_hash, hashed)
block_hash = await defer_to_thread(hashutil.block_hash, hashed)
self.blockhashes[shareid][segnum] = block_hash
# find the writer for this share
writers = self.writers[shareid]

View File

@ -769,9 +769,9 @@ class Retrieve(object):
"block hash tree failure: %s" % e)
if self._version == MDMF_VERSION:
blockhash = await defer_to_thread(reactor, hashutil.block_hash, salt + block)
blockhash = await defer_to_thread(hashutil.block_hash, salt + block)
else:
blockhash = await defer_to_thread(reactor, hashutil.block_hash, block)
blockhash = await defer_to_thread(hashutil.block_hash, block)
# If this works without an error, then validation is
# successful.
try:

View File

@ -541,9 +541,7 @@ class StorageClient(object):
"Can't use both `message_to_serialize` and `data` "
"as keyword arguments at the same time"
)
kwargs["data"] = await defer_to_thread(
self._clock, dumps, message_to_serialize
)
kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
response = await self._treq.request(
@ -566,7 +564,7 @@ class StorageClient(object):
schema.validate_cbor(data)
return loads(data)
return await defer_to_thread(self._clock, validate_and_decode)
return await defer_to_thread(validate_and_decode)
else:
raise ClientException(
-1,

View File

@ -638,7 +638,7 @@ async def read_encoded(
# Pycddl will release the GIL when validating larger documents, so
# let's take advantage of multiple CPUs:
await defer_to_thread(reactor, schema.validate_cbor, message)
await defer_to_thread(schema.validate_cbor, message)
# The CBOR parser will allocate more memory, but at least we can feed
# it the file-like object, so that if it's large it won't be make two

View File

@ -317,7 +317,7 @@ def make_nodemaker_with_storage_broker(storage_broker):
:param StorageFarmBroker peers: The storage broker to use.
"""
sh = client.SecretHolder(b"lease secret", b"convergence secret")
keygen = client.KeyGenerator(reactor)
keygen = client.KeyGenerator()
nodemaker = NodeMaker(storage_broker, sh, None,
None, None,
{"k": 3, "n": 10}, SDMF_VERSION, keygen)

View File

@ -246,15 +246,13 @@ def create_no_network_client(basedir):
from allmydata.client import read_config
config = read_config(basedir, u'client.port')
storage_broker = NoNetworkStorageBroker()
from twisted.internet import reactor
client = _NoNetworkClient(
config,
main_tub=None,
i2p_provider=None,
tor_provider=None,
introducer_clients=[],
storage_farm_broker=storage_broker,
reactor=reactor,
storage_farm_broker=storage_broker
)
# this is a (pre-existing) reference-cycle and also a bad idea, see:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949

View File

@ -43,6 +43,7 @@ from testtools.matchers import Equals
from zope.interface import implementer
from ..util.deferredutil import async_to_deferred
from ..util.cputhreadpool import disable_thread_pool_for_test
from .common import SyncTestCase
from ..storage.http_common import (
get_content_type,
@ -345,6 +346,7 @@ class CustomHTTPServerTests(SyncTestCase):
def setUp(self):
super(CustomHTTPServerTests, self).setUp()
disable_thread_pool_for_test(self)
StorageClientFactory.start_test_mode(
lambda pool: self.addCleanup(pool.closeCachedConnections)
)
@ -701,6 +703,7 @@ class GenericHTTPAPITests(SyncTestCase):
def setUp(self):
super(GenericHTTPAPITests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
def test_missing_authentication(self) -> None:
@ -808,6 +811,7 @@ class ImmutableHTTPAPITests(SyncTestCase):
def setUp(self):
super(ImmutableHTTPAPITests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.imm_client = StorageClientImmutables(self.http.client)
self.general_client = StorageClientGeneral(self.http.client)
@ -1317,6 +1321,7 @@ class MutableHTTPAPIsTests(SyncTestCase):
def setUp(self):
super(MutableHTTPAPIsTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.mut_client = StorageClientMutables(self.http.client)
@ -1734,6 +1739,7 @@ class ImmutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
def setUp(self):
super(ImmutableSharedTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.client = self.clientFactory(self.http.client)
self.general_client = StorageClientGeneral(self.http.client)
@ -1788,6 +1794,7 @@ class MutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
def setUp(self):
super(MutableSharedTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.client = self.clientFactory(self.http.client)
self.general_client = StorageClientGeneral(self.http.client)

View File

@ -28,7 +28,7 @@ from allmydata.util import pollmixin
from allmydata.util import yamlutil
from allmydata.util import rrefutil
from allmydata.util.fileutil import EncryptedTemporaryFile
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.cputhreadpool import defer_to_thread, disable_thread_pool_for_test
from allmydata.test.common_util import ReallyEqualMixin
from .no_network import fireNow, LocalWrapper
@ -613,7 +613,7 @@ class CPUThreadPool(unittest.TestCase):
return current_thread(), args, kwargs
this_thread = current_thread().ident
result = defer_to_thread(reactor, f, 1, 3, key=4, value=5)
result = defer_to_thread(f, 1, 3, key=4, value=5)
# Callbacks run in the correct thread:
callback_thread_ident = []
@ -630,3 +630,21 @@ class CPUThreadPool(unittest.TestCase):
self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5})
def test_when_disabled_runs_in_same_thread(self):
"""
If the CPU thread pool is disabled, the given function runs in the
current thread.
"""
disable_thread_pool_for_test(self)
def f(*args, **kwargs):
return current_thread().ident, args, kwargs
this_thread = current_thread().ident
result = defer_to_thread(f, 1, 3, key=4, value=5)
l = []
result.addCallback(l.append)
thread, args, kwargs = l[0]
self.assertEqual(thread, this_thread)
self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5})

View File

@ -19,12 +19,13 @@ from typing import TypeVar, Callable, cast
from functools import partial
import threading
from typing_extensions import ParamSpec
from unittest import TestCase
from twisted.python.threadpool import ThreadPool
from twisted.internet.defer import Deferred
from twisted.internet.defer import Deferred, maybeDeferred
from twisted.internet.threads import deferToThreadPool
from twisted.internet.interfaces import IReactorFromThreads
from twisted.internet import reactor
_CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU")
if hasattr(threading, "_register_atexit"):
@ -46,14 +47,37 @@ _CPU_THREAD_POOL.start()
P = ParamSpec("P")
R = TypeVar("R")
# Is running in a thread pool disabled? Should only be true in synchronous unit
# tests.
_DISABLED = False
def defer_to_thread(
reactor: IReactorFromThreads, f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Deferred[R]:
"""Run the function in a thread, return the result as a ``Deferred``."""
if _DISABLED:
return maybeDeferred(f, *args, **kwargs)
# deferToThreadPool has no type annotations...
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
return cast(Deferred[R], result)
__all__ = ["defer_to_thread"]
def disable_thread_pool_for_test(test: TestCase) -> None:
"""
For the duration of the test, calls to C{defer_to_thread} will actually run
synchronously.
"""
global _DISABLED
def restore():
global _DISABLED
_DISABLED = False
test.addCleanup(restore)
_DISABLED = True
__all__ = ["defer_to_thread", "disable_thread_pool_for_test"]