Merge pull request #1342 from tahoe-lafs/4068-reduce-cpu-in-eventloop-thread

Reduce blocking operations in eventloop thread
This commit is contained in:
Itamar Turner-Trauring 2023-10-20 10:48:27 -04:00 committed by GitHub
commit 4fbf31b00c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 185 additions and 73 deletions

View File

@ -0,0 +1 @@
Some operations now run in threads, improving the responsiveness of Tahoe nodes.

View File

@ -47,6 +47,8 @@ from allmydata.util.abbreviate import parse_abbreviated_size
from allmydata.util.time_format import parse_duration, parse_date from allmydata.util.time_format import parse_duration, parse_date
from allmydata.util.i2p_provider import create as create_i2p_provider from allmydata.util.i2p_provider import create as create_i2p_provider
from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.deferredutil import async_to_deferred
from allmydata.stats import StatsProvider from allmydata.stats import StatsProvider
from allmydata.history import History from allmydata.history import History
from allmydata.interfaces import ( from allmydata.interfaces import (
@ -170,12 +172,18 @@ class KeyGenerator(object):
"""I create RSA keys for mutable files. Each call to generate() returns a """I create RSA keys for mutable files. Each call to generate() returns a
single keypair.""" single keypair."""
def generate(self): @async_to_deferred
"""I return a Deferred that fires with a (verifyingkey, signingkey) async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]:
pair. The returned key will be 2048 bit""" """
I return a Deferred that fires with a (verifyingkey, signingkey)
pair. The returned key will be 2048 bit.
"""
keysize = 2048 keysize = 2048
signer, verifier = rsa.create_signing_keypair(keysize) private, public = await defer_to_thread(
return defer.succeed( (verifier, signer) ) rsa.create_signing_keypair, keysize
)
return public, private
class Terminator(service.Service): class Terminator(service.Service):
def __init__(self): def __init__(self):

View File

@ -13,9 +13,10 @@ if PY2:
from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer
from allmydata.util import mathutil from allmydata.util import mathutil
from allmydata.util.assertutil import precondition from allmydata.util.assertutil import precondition
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.deferredutil import async_to_deferred
from allmydata.interfaces import ICodecEncoder, ICodecDecoder from allmydata.interfaces import ICodecEncoder, ICodecDecoder
import zfec import zfec
@ -45,7 +46,8 @@ class CRSEncoder(object):
def get_block_size(self): def get_block_size(self):
return self.share_size return self.share_size
def encode(self, inshares, desired_share_ids=None): @async_to_deferred
async def encode(self, inshares, desired_share_ids=None):
precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares) precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares)
if desired_share_ids is None: if desired_share_ids is None:
@ -53,9 +55,8 @@ class CRSEncoder(object):
for inshare in inshares: for inshare in inshares:
assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares) assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
shares = self.encoder.encode(inshares, desired_share_ids) shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
return (shares, desired_share_ids)
return defer.succeed((shares, desired_share_ids))
def encode_proposal(self, data, desired_share_ids=None): def encode_proposal(self, data, desired_share_ids=None):
raise NotImplementedError() raise NotImplementedError()
@ -77,14 +78,17 @@ class CRSDecoder(object):
def get_needed_shares(self): def get_needed_shares(self):
return self.required_shares return self.required_shares
def decode(self, some_shares, their_shareids): @async_to_deferred
async def decode(self, some_shares, their_shareids):
precondition(len(some_shares) == len(their_shareids), precondition(len(some_shares) == len(their_shareids),
len(some_shares), len(their_shareids)) len(some_shares), len(their_shareids))
precondition(len(some_shares) == self.required_shares, precondition(len(some_shares) == self.required_shares,
len(some_shares), self.required_shares) len(some_shares), self.required_shares)
data = self.decoder.decode(some_shares, return await defer_to_thread(
[int(s) for s in their_shareids]) self.decoder.decode,
return defer.succeed(data) some_shares,
[int(s) for s in their_shareids]
)
def parse_params(serializedparams): def parse_params(serializedparams):
pieces = serializedparams.split(b"-") pieces = serializedparams.split(b"-")

View File

@ -23,6 +23,8 @@ from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
IMutableUploadable IMutableUploadable
from allmydata.util import base32, hashutil, mathutil, log from allmydata.util import base32, hashutil, mathutil, log
from allmydata.util.dictutil import DictOfSets from allmydata.util.dictutil import DictOfSets
from allmydata.util.deferredutil import async_to_deferred
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata import hashtree, codec from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a from allmydata.storage.server import si_b2a
from foolscap.api import eventually, fireEventually from foolscap.api import eventually, fireEventually
@ -706,7 +708,8 @@ class Publish(object):
writer.put_salt(salt) writer.put_salt(salt)
def _encode_segment(self, segnum): @async_to_deferred
async def _encode_segment(self, segnum):
""" """
I encrypt and encode the segment segnum. I encrypt and encode the segment segnum.
""" """
@ -726,13 +729,17 @@ class Publish(object):
assert len(data) == segsize, len(data) assert len(data) == segsize, len(data)
salt = os.urandom(16)
key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
self._status.set_status("Encrypting") self._status.set_status("Encrypting")
encryptor = aes.create_encryptor(key)
crypttext = aes.encrypt_data(encryptor, data) def encrypt(readkey):
assert len(crypttext) == len(data) salt = os.urandom(16)
key = hashutil.ssk_readkey_data_hash(salt, readkey)
encryptor = aes.create_encryptor(key)
crypttext = aes.encrypt_data(encryptor, data)
assert len(crypttext) == len(data)
return salt, crypttext
salt, crypttext = await defer_to_thread(encrypt, self.readkey)
now = time.time() now = time.time()
self._status.accumulate_encrypt_time(now - started) self._status.accumulate_encrypt_time(now - started)
@ -753,16 +760,14 @@ class Publish(object):
piece = piece + b"\x00"*(piece_size - len(piece)) # padding piece = piece + b"\x00"*(piece_size - len(piece)) # padding
crypttext_pieces[i] = piece crypttext_pieces[i] = piece
assert len(piece) == piece_size assert len(piece) == piece_size
d = fec.encode(crypttext_pieces)
def _done_encoding(res):
elapsed = time.time() - started
self._status.accumulate_encode_time(elapsed)
return (res, salt)
d.addCallback(_done_encoding)
return d
res = await fec.encode(crypttext_pieces)
elapsed = time.time() - started
self._status.accumulate_encode_time(elapsed)
return (res, salt)
def _push_segment(self, encoded_and_salt, segnum): @async_to_deferred
async def _push_segment(self, encoded_and_salt, segnum):
""" """
I push (data, salt) as segment number segnum. I push (data, salt) as segment number segnum.
""" """
@ -776,7 +781,7 @@ class Publish(object):
hashed = salt + sharedata hashed = salt + sharedata
else: else:
hashed = sharedata hashed = sharedata
block_hash = hashutil.block_hash(hashed) block_hash = await defer_to_thread(hashutil.block_hash, hashed)
self.blockhashes[shareid][segnum] = block_hash self.blockhashes[shareid][segnum] = block_hash
# find the writer for this share # find the writer for this share
writers = self.writers[shareid] writers = self.writers[shareid]

View File

@ -20,6 +20,7 @@ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
from allmydata.util.assertutil import _assert, precondition from allmydata.util.assertutil import _assert, precondition
from allmydata.util import hashutil, log, mathutil, deferredutil from allmydata.util import hashutil, log, mathutil, deferredutil
from allmydata.util.dictutil import DictOfSets from allmydata.util.dictutil import DictOfSets
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata import hashtree, codec from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a from allmydata.storage.server import si_b2a
@ -734,7 +735,8 @@ class Retrieve(object):
return None return None
def _validate_block(self, results, segnum, reader, server, started): @deferredutil.async_to_deferred
async def _validate_block(self, results, segnum, reader, server, started):
""" """
I validate a block from one share on a remote server. I validate a block from one share on a remote server.
""" """
@ -767,9 +769,9 @@ class Retrieve(object):
"block hash tree failure: %s" % e) "block hash tree failure: %s" % e)
if self._version == MDMF_VERSION: if self._version == MDMF_VERSION:
blockhash = hashutil.block_hash(salt + block) blockhash = await defer_to_thread(hashutil.block_hash, salt + block)
else: else:
blockhash = hashutil.block_hash(block) blockhash = await defer_to_thread(hashutil.block_hash, block)
# If this works without an error, then validation is # If this works without an error, then validation is
# successful. # successful.
try: try:
@ -893,8 +895,8 @@ class Retrieve(object):
d.addCallback(_process) d.addCallback(_process)
return d return d
@deferredutil.async_to_deferred
def _decrypt_segment(self, segment_and_salt): async def _decrypt_segment(self, segment_and_salt):
""" """
I take a single segment and its salt, and decrypt it. I return I take a single segment and its salt, and decrypt it. I return
the plaintext of the segment that is in my argument. the plaintext of the segment that is in my argument.
@ -903,9 +905,14 @@ class Retrieve(object):
self._set_current_status("decrypting") self._set_current_status("decrypting")
self.log("decrypting segment %d" % self._current_segment) self.log("decrypting segment %d" % self._current_segment)
started = time.time() started = time.time()
key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey()) readkey = self._node.get_readkey()
decryptor = aes.create_decryptor(key)
plaintext = aes.decrypt_data(decryptor, segment) def decrypt():
key = hashutil.ssk_readkey_data_hash(salt, readkey)
decryptor = aes.create_decryptor(key)
return aes.decrypt_data(decryptor, segment)
plaintext = await defer_to_thread(decrypt)
self._status.accumulate_decrypt_time(time.time() - started) self._status.accumulate_decrypt_time(time.time() - started)
return plaintext return plaintext

View File

@ -64,6 +64,7 @@ from .common import si_b2a, si_to_human_readable
from ..util.hashutil import timing_safe_compare from ..util.hashutil import timing_safe_compare
from ..util.deferredutil import async_to_deferred from ..util.deferredutil import async_to_deferred
from ..util.tor_provider import _Provider as TorProvider from ..util.tor_provider import _Provider as TorProvider
from ..util.cputhreadpool import defer_to_thread
try: try:
from txtorcon import Tor # type: ignore from txtorcon import Tor # type: ignore
@ -473,7 +474,8 @@ class StorageClient(object):
into corresponding HTTP headers. into corresponding HTTP headers.
If ``message_to_serialize`` is set, it will be serialized (by default If ``message_to_serialize`` is set, it will be serialized (by default
with CBOR) and set as the request body. with CBOR) and set as the request body. It should not be mutated
during execution of this function!
Default timeout is 60 seconds. Default timeout is 60 seconds.
""" """
@ -539,7 +541,7 @@ class StorageClient(object):
"Can't use both `message_to_serialize` and `data` " "Can't use both `message_to_serialize` and `data` "
"as keyword arguments at the same time" "as keyword arguments at the same time"
) )
kwargs["data"] = dumps(message_to_serialize) kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE) headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
response = await self._treq.request( response = await self._treq.request(
@ -557,8 +559,12 @@ class StorageClient(object):
if content_type == CBOR_MIME_TYPE: if content_type == CBOR_MIME_TYPE:
f = await limited_content(response, self._clock) f = await limited_content(response, self._clock)
data = f.read() data = f.read()
schema.validate_cbor(data)
return loads(data) def validate_and_decode():
schema.validate_cbor(data)
return loads(data)
return await defer_to_thread(validate_and_decode)
else: else:
raise ClientException( raise ClientException(
-1, -1,
@ -1232,7 +1238,8 @@ class StorageClientMutables:
return cast( return cast(
Set[int], Set[int],
await self._client.decode_cbor( await self._client.decode_cbor(
response, _SCHEMAS["mutable_list_shares"] response,
_SCHEMAS["mutable_list_shares"],
), ),
) )
else: else:

View File

@ -638,10 +638,7 @@ async def read_encoded(
# Pycddl will release the GIL when validating larger documents, so # Pycddl will release the GIL when validating larger documents, so
# let's take advantage of multiple CPUs: # let's take advantage of multiple CPUs:
if size > 10_000: await defer_to_thread(schema.validate_cbor, message)
await defer_to_thread(reactor, schema.validate_cbor, message)
else:
schema.validate_cbor(message)
# The CBOR parser will allocate more memory, but at least we can feed # The CBOR parser will allocate more memory, but at least we can feed
# it the file-like object, so that if it's large it won't be make two # it the file-like object, so that if it's large it won't be make two

View File

@ -0,0 +1,40 @@
import sys
import traceback
import signal
import threading
from twisted.internet import reactor
def print_stacks():
print("Uh oh, something is blocking the event loop!")
current_thread = threading.get_ident()
for thread_id, frame in sys._current_frames().items():
if thread_id == current_thread:
traceback.print_stack(frame, limit=10)
break
def catch_blocking_in_event_loop(test=None):
"""
Print tracebacks if the event loop is blocked for more than a short amount
of time.
"""
signal.signal(signal.SIGALRM, lambda *args: print_stacks())
current_scheduled = [None]
def cancel_and_rerun():
signal.setitimer(signal.ITIMER_REAL, 0)
signal.setitimer(signal.ITIMER_REAL, 0.015)
current_scheduled[0] = reactor.callLater(0.01, cancel_and_rerun)
cancel_and_rerun()
def cleanup():
signal.signal(signal.SIGALRM, signal.SIG_DFL)
signal.setitimer(signal.ITIMER_REAL, 0)
current_scheduled[0].cancel()
if test is not None:
test.addCleanup(cleanup)

View File

@ -424,7 +424,7 @@ class Check(GridTestMixin, CLITestMixin, unittest.TestCase):
def _stash_uri(n): def _stash_uri(n):
self.uriList.append(n.get_uri()) self.uriList.append(n.get_uri())
d.addCallback(_stash_uri) d.addCallback(_stash_uri)
d = c0.create_dirnode() d.addCallback(lambda _: c0.create_dirnode())
d.addCallback(_stash_uri) d.addCallback(_stash_uri)
d.addCallback(lambda ign: self.do_cli("check", self.uriList[0], self.uriList[1])) d.addCallback(lambda ign: self.do_cli("check", self.uriList[0], self.uriList[1]))

View File

@ -685,6 +685,10 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
REDUCE_HTTP_CLIENT_TIMEOUT : bool = True REDUCE_HTTP_CLIENT_TIMEOUT : bool = True
def setUp(self): def setUp(self):
if os.getenv("TAHOE_DEBUG_BLOCKING") == "1":
from .blocking import catch_blocking_in_event_loop
catch_blocking_in_event_loop(self)
self._http_client_pools = [] self._http_client_pools = []
http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool) http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool)
self.addCleanup(http_client.StorageClientFactory.stop_test_mode) self.addCleanup(http_client.StorageClientFactory.stop_test_mode)

View File

@ -252,7 +252,7 @@ def create_no_network_client(basedir):
i2p_provider=None, i2p_provider=None,
tor_provider=None, tor_provider=None,
introducer_clients=[], introducer_clients=[],
storage_farm_broker=storage_broker, storage_farm_broker=storage_broker
) )
# this is a (pre-existing) reference-cycle and also a bad idea, see: # this is a (pre-existing) reference-cycle and also a bad idea, see:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949 # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949

View File

@ -43,6 +43,7 @@ from testtools.matchers import Equals
from zope.interface import implementer from zope.interface import implementer
from ..util.deferredutil import async_to_deferred from ..util.deferredutil import async_to_deferred
from ..util.cputhreadpool import disable_thread_pool_for_test
from .common import SyncTestCase from .common import SyncTestCase
from ..storage.http_common import ( from ..storage.http_common import (
get_content_type, get_content_type,
@ -345,6 +346,7 @@ class CustomHTTPServerTests(SyncTestCase):
def setUp(self): def setUp(self):
super(CustomHTTPServerTests, self).setUp() super(CustomHTTPServerTests, self).setUp()
disable_thread_pool_for_test(self)
StorageClientFactory.start_test_mode( StorageClientFactory.start_test_mode(
lambda pool: self.addCleanup(pool.closeCachedConnections) lambda pool: self.addCleanup(pool.closeCachedConnections)
) )
@ -701,6 +703,7 @@ class GenericHTTPAPITests(SyncTestCase):
def setUp(self): def setUp(self):
super(GenericHTTPAPITests, self).setUp() super(GenericHTTPAPITests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture()) self.http = self.useFixture(HttpTestFixture())
def test_missing_authentication(self) -> None: def test_missing_authentication(self) -> None:
@ -808,6 +811,7 @@ class ImmutableHTTPAPITests(SyncTestCase):
def setUp(self): def setUp(self):
super(ImmutableHTTPAPITests, self).setUp() super(ImmutableHTTPAPITests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture()) self.http = self.useFixture(HttpTestFixture())
self.imm_client = StorageClientImmutables(self.http.client) self.imm_client = StorageClientImmutables(self.http.client)
self.general_client = StorageClientGeneral(self.http.client) self.general_client = StorageClientGeneral(self.http.client)
@ -1317,6 +1321,7 @@ class MutableHTTPAPIsTests(SyncTestCase):
def setUp(self): def setUp(self):
super(MutableHTTPAPIsTests, self).setUp() super(MutableHTTPAPIsTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture()) self.http = self.useFixture(HttpTestFixture())
self.mut_client = StorageClientMutables(self.http.client) self.mut_client = StorageClientMutables(self.http.client)
@ -1734,6 +1739,7 @@ class ImmutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
def setUp(self): def setUp(self):
super(ImmutableSharedTests, self).setUp() super(ImmutableSharedTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture()) self.http = self.useFixture(HttpTestFixture())
self.client = self.clientFactory(self.http.client) self.client = self.clientFactory(self.http.client)
self.general_client = StorageClientGeneral(self.http.client) self.general_client = StorageClientGeneral(self.http.client)
@ -1788,6 +1794,7 @@ class MutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
def setUp(self): def setUp(self):
super(MutableSharedTests, self).setUp() super(MutableSharedTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture()) self.http = self.useFixture(HttpTestFixture())
self.client = self.clientFactory(self.http.client) self.client = self.clientFactory(self.http.client)
self.general_client = StorageClientGeneral(self.http.client) self.general_client = StorageClientGeneral(self.http.client)

View File

@ -18,7 +18,6 @@ import json
from threading import current_thread from threading import current_thread
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import reactor
from foolscap.api import Violation, RemoteException from foolscap.api import Violation, RemoteException
from allmydata.util import idlib, mathutil from allmydata.util import idlib, mathutil
@ -28,7 +27,7 @@ from allmydata.util import pollmixin
from allmydata.util import yamlutil from allmydata.util import yamlutil
from allmydata.util import rrefutil from allmydata.util import rrefutil
from allmydata.util.fileutil import EncryptedTemporaryFile from allmydata.util.fileutil import EncryptedTemporaryFile
from allmydata.util.cputhreadpool import defer_to_thread from allmydata.util.cputhreadpool import defer_to_thread, disable_thread_pool_for_test
from allmydata.test.common_util import ReallyEqualMixin from allmydata.test.common_util import ReallyEqualMixin
from .no_network import fireNow, LocalWrapper from .no_network import fireNow, LocalWrapper
@ -613,20 +612,25 @@ class CPUThreadPool(unittest.TestCase):
return current_thread(), args, kwargs return current_thread(), args, kwargs
this_thread = current_thread().ident this_thread = current_thread().ident
result = defer_to_thread(reactor, f, 1, 3, key=4, value=5) thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
# Callbacks run in the correct thread:
callback_thread_ident = []
def passthrough(result):
callback_thread_ident.append(current_thread().ident)
return result
result.addCallback(passthrough)
# The task ran in a different thread: # The task ran in a different thread:
thread, args, kwargs = await result
self.assertEqual(callback_thread_ident[0], this_thread)
self.assertNotEqual(thread.ident, this_thread) self.assertNotEqual(thread.ident, this_thread)
self.assertEqual(args, (1, 3)) self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5}) self.assertEqual(kwargs, {"key": 4, "value": 5})
async def test_when_disabled_runs_in_same_thread(self):
"""
If the CPU thread pool is disabled, the given function runs in the
current thread.
"""
disable_thread_pool_for_test(self)
def f(*args, **kwargs):
return current_thread().ident, args, kwargs
this_thread = current_thread().ident
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
self.assertEqual(thread, this_thread)
self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5})

View File

@ -15,16 +15,15 @@ scheduler affinity or cgroups, but that's not the end of the world.
""" """
import os import os
from typing import TypeVar, Callable, cast from typing import TypeVar, Callable
from functools import partial from functools import partial
import threading import threading
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from unittest import TestCase
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.internet.defer import Deferred
from twisted.internet.threads import deferToThreadPool from twisted.internet.threads import deferToThreadPool
from twisted.internet.interfaces import IReactorFromThreads from twisted.internet import reactor
_CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU") _CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU")
if hasattr(threading, "_register_atexit"): if hasattr(threading, "_register_atexit"):
@ -46,14 +45,43 @@ _CPU_THREAD_POOL.start()
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
# Is running in a thread pool disabled? Should only be true in synchronous unit
# tests.
_DISABLED = False
async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Run the function in a thread, return the result.
However, if ``disable_thread_pool_for_test()`` was called the function will
be called synchronously inside the current thread.
To reduce chances of synchronous tests being misleading as a result, this
is an async function on presumption that will encourage immediate ``await``ing.
"""
if _DISABLED:
return f(*args, **kwargs)
def defer_to_thread(
reactor: IReactorFromThreads, f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Deferred[R]:
"""Run the function in a thread, return the result as a ``Deferred``."""
# deferToThreadPool has no type annotations... # deferToThreadPool has no type annotations...
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs) result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
return cast(Deferred[R], result) return result
__all__ = ["defer_to_thread"] def disable_thread_pool_for_test(test: TestCase) -> None:
"""
For the duration of the test, calls to ``defer_to_thread()`` will actually
run synchronously, which is useful for synchronous unit tests.
"""
global _DISABLED
def restore():
global _DISABLED
_DISABLED = False
test.addCleanup(restore)
_DISABLED = True
__all__ = ["defer_to_thread", "disable_thread_pool_for_test"]