diff --git a/src/allmydata/client.py b/src/allmydata/client.py index 803d2946e..dd3c912de 100644 --- a/src/allmydata/client.py +++ b/src/allmydata/client.py @@ -48,6 +48,7 @@ from allmydata.util.time_format import parse_duration, parse_date from allmydata.util.i2p_provider import create as create_i2p_provider from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider from allmydata.util.cputhreadpool import defer_to_thread +from allmydata.util.deferredutil import async_to_deferred from allmydata.stats import StatsProvider from allmydata.history import History from allmydata.interfaces import ( @@ -171,15 +172,17 @@ class KeyGenerator(object): """I create RSA keys for mutable files. Each call to generate() returns a single keypair.""" - def generate(self) -> defer.Deferred[tuple[rsa.PublicKey, rsa.PrivateKey]]: + @async_to_deferred + async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]: """ I return a Deferred that fires with a (verifyingkey, signingkey) pair. The returned key will be 2048 bit. """ keysize = 2048 - return defer_to_thread( + private, public = await defer_to_thread( rsa.create_signing_keypair, keysize - ).addCallback(lambda t: (t[1], t[0])) + ) + return public, private class Terminator(service.Service): diff --git a/src/allmydata/codec.py b/src/allmydata/codec.py index 402af9204..af375a117 100644 --- a/src/allmydata/codec.py +++ b/src/allmydata/codec.py @@ -16,6 +16,7 @@ from zope.interface import implementer from allmydata.util import mathutil from allmydata.util.assertutil import precondition from allmydata.util.cputhreadpool import defer_to_thread +from allmydata.util.deferredutil import async_to_deferred from allmydata.interfaces import ICodecEncoder, ICodecDecoder import zfec @@ -45,7 +46,8 @@ class CRSEncoder(object): def get_block_size(self): return self.share_size - def encode(self, inshares, desired_share_ids=None): + @async_to_deferred + async def encode(self, inshares, desired_share_ids=None): precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares) if desired_share_ids is None: @@ -53,9 +55,8 @@ class CRSEncoder(object): for inshare in inshares: assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares) - d = defer_to_thread(self.encoder.encode, inshares, desired_share_ids) - d.addCallback(lambda shares: (shares, desired_share_ids)) - return d + shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids) + return (shares, desired_share_ids) def encode_proposal(self, data, desired_share_ids=None): raise NotImplementedError() @@ -77,12 +78,13 @@ class CRSDecoder(object): def get_needed_shares(self): return self.required_shares - def decode(self, some_shares, their_shareids): + @async_to_deferred + async def decode(self, some_shares, their_shareids): precondition(len(some_shares) == len(their_shareids), len(some_shares), len(their_shareids)) precondition(len(some_shares) == self.required_shares, len(some_shares), self.required_shares) - return defer_to_thread( + return await defer_to_thread( self.decoder.decode, some_shares, [int(s) for s in their_shareids] diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py index 3b77b55a4..d3a36a756 100644 --- a/src/allmydata/test/test_util.py +++ b/src/allmydata/test/test_util.py @@ -612,24 +612,14 @@ class CPUThreadPool(unittest.TestCase): return current_thread(), args, kwargs this_thread = current_thread().ident - result = defer_to_thread(f, 1, 3, key=4, value=5) - - # Callbacks run in the correct thread: - callback_thread_ident = [] - def passthrough(result): - callback_thread_ident.append(current_thread().ident) - return result - - result.addCallback(passthrough) + thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5) # The task ran in a different thread: - thread, args, kwargs = await result - self.assertEqual(callback_thread_ident[0], this_thread) self.assertNotEqual(thread.ident, this_thread) self.assertEqual(args, (1, 3)) self.assertEqual(kwargs, {"key": 4, "value": 5}) - def test_when_disabled_runs_in_same_thread(self): + async def test_when_disabled_runs_in_same_thread(self): """ If the CPU thread pool is disabled, the given function runs in the current thread. @@ -639,10 +629,7 @@ class CPUThreadPool(unittest.TestCase): return current_thread().ident, args, kwargs this_thread = current_thread().ident - result = defer_to_thread(f, 1, 3, key=4, value=5) - l = [] - result.addCallback(l.append) - thread, args, kwargs = l[0] + thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5) self.assertEqual(thread, this_thread) self.assertEqual(args, (1, 3)) diff --git a/src/allmydata/util/cputhreadpool.py b/src/allmydata/util/cputhreadpool.py index 5c93e9e30..032a3a823 100644 --- a/src/allmydata/util/cputhreadpool.py +++ b/src/allmydata/util/cputhreadpool.py @@ -15,14 +15,13 @@ scheduler affinity or cgroups, but that's not the end of the world. """ import os -from typing import TypeVar, Callable, cast +from typing import TypeVar, Callable from functools import partial import threading from typing_extensions import ParamSpec from unittest import TestCase from twisted.python.threadpool import ThreadPool -from twisted.internet.defer import Deferred, maybeDeferred from twisted.internet.threads import deferToThreadPool from twisted.internet import reactor @@ -51,21 +50,22 @@ R = TypeVar("R") _DISABLED = False -def defer_to_thread( - f: Callable[P, R], *args: P.args, **kwargs: P.kwargs -) -> Deferred[R]: +async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ - Run the function in a thread, return the result as a ``Deferred``. + Run the function in a thread, return the result. However, if ``disable_thread_pool_for_test()`` was called the function will be called synchronously inside the current thread. + + To reduce chances of synchronous tests being misleading as a result, this + is an async function on presumption that will encourage immediate ``await``ing. """ if _DISABLED: - return maybeDeferred(f, *args, **kwargs) + return f(*args, **kwargs) # deferToThreadPool has no type annotations... - result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs) - return cast(Deferred[R], result) + result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs) + return result def disable_thread_pool_for_test(test: TestCase) -> None: