mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2024-12-18 20:47:54 +00:00
Merge pull request #1342 from tahoe-lafs/4068-reduce-cpu-in-eventloop-thread
Reduce blocking operations in eventloop thread
This commit is contained in:
commit
4fbf31b00c
1
newsfragments/4068.feature
Normal file
1
newsfragments/4068.feature
Normal file
@ -0,0 +1 @@
|
||||
Some operations now run in threads, improving the responsiveness of Tahoe nodes.
|
@ -47,6 +47,8 @@ from allmydata.util.abbreviate import parse_abbreviated_size
|
||||
from allmydata.util.time_format import parse_duration, parse_date
|
||||
from allmydata.util.i2p_provider import create as create_i2p_provider
|
||||
from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata.util.deferredutil import async_to_deferred
|
||||
from allmydata.stats import StatsProvider
|
||||
from allmydata.history import History
|
||||
from allmydata.interfaces import (
|
||||
@ -170,12 +172,18 @@ class KeyGenerator(object):
|
||||
"""I create RSA keys for mutable files. Each call to generate() returns a
|
||||
single keypair."""
|
||||
|
||||
def generate(self):
|
||||
"""I return a Deferred that fires with a (verifyingkey, signingkey)
|
||||
pair. The returned key will be 2048 bit"""
|
||||
@async_to_deferred
|
||||
async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]:
|
||||
"""
|
||||
I return a Deferred that fires with a (verifyingkey, signingkey)
|
||||
pair. The returned key will be 2048 bit.
|
||||
"""
|
||||
keysize = 2048
|
||||
signer, verifier = rsa.create_signing_keypair(keysize)
|
||||
return defer.succeed( (verifier, signer) )
|
||||
private, public = await defer_to_thread(
|
||||
rsa.create_signing_keypair, keysize
|
||||
)
|
||||
return public, private
|
||||
|
||||
|
||||
class Terminator(service.Service):
|
||||
def __init__(self):
|
||||
|
@ -13,9 +13,10 @@ if PY2:
|
||||
from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
|
||||
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import defer
|
||||
from allmydata.util import mathutil
|
||||
from allmydata.util.assertutil import precondition
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata.util.deferredutil import async_to_deferred
|
||||
from allmydata.interfaces import ICodecEncoder, ICodecDecoder
|
||||
import zfec
|
||||
|
||||
@ -45,7 +46,8 @@ class CRSEncoder(object):
|
||||
def get_block_size(self):
|
||||
return self.share_size
|
||||
|
||||
def encode(self, inshares, desired_share_ids=None):
|
||||
@async_to_deferred
|
||||
async def encode(self, inshares, desired_share_ids=None):
|
||||
precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares)
|
||||
|
||||
if desired_share_ids is None:
|
||||
@ -53,9 +55,8 @@ class CRSEncoder(object):
|
||||
|
||||
for inshare in inshares:
|
||||
assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
|
||||
shares = self.encoder.encode(inshares, desired_share_ids)
|
||||
|
||||
return defer.succeed((shares, desired_share_ids))
|
||||
shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
|
||||
return (shares, desired_share_ids)
|
||||
|
||||
def encode_proposal(self, data, desired_share_ids=None):
|
||||
raise NotImplementedError()
|
||||
@ -77,14 +78,17 @@ class CRSDecoder(object):
|
||||
def get_needed_shares(self):
|
||||
return self.required_shares
|
||||
|
||||
def decode(self, some_shares, their_shareids):
|
||||
@async_to_deferred
|
||||
async def decode(self, some_shares, their_shareids):
|
||||
precondition(len(some_shares) == len(their_shareids),
|
||||
len(some_shares), len(their_shareids))
|
||||
precondition(len(some_shares) == self.required_shares,
|
||||
len(some_shares), self.required_shares)
|
||||
data = self.decoder.decode(some_shares,
|
||||
[int(s) for s in their_shareids])
|
||||
return defer.succeed(data)
|
||||
return await defer_to_thread(
|
||||
self.decoder.decode,
|
||||
some_shares,
|
||||
[int(s) for s in their_shareids]
|
||||
)
|
||||
|
||||
def parse_params(serializedparams):
|
||||
pieces = serializedparams.split(b"-")
|
||||
|
@ -23,6 +23,8 @@ from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
|
||||
IMutableUploadable
|
||||
from allmydata.util import base32, hashutil, mathutil, log
|
||||
from allmydata.util.dictutil import DictOfSets
|
||||
from allmydata.util.deferredutil import async_to_deferred
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata import hashtree, codec
|
||||
from allmydata.storage.server import si_b2a
|
||||
from foolscap.api import eventually, fireEventually
|
||||
@ -706,7 +708,8 @@ class Publish(object):
|
||||
writer.put_salt(salt)
|
||||
|
||||
|
||||
def _encode_segment(self, segnum):
|
||||
@async_to_deferred
|
||||
async def _encode_segment(self, segnum):
|
||||
"""
|
||||
I encrypt and encode the segment segnum.
|
||||
"""
|
||||
@ -726,13 +729,17 @@ class Publish(object):
|
||||
|
||||
assert len(data) == segsize, len(data)
|
||||
|
||||
salt = os.urandom(16)
|
||||
|
||||
key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
|
||||
self._status.set_status("Encrypting")
|
||||
encryptor = aes.create_encryptor(key)
|
||||
crypttext = aes.encrypt_data(encryptor, data)
|
||||
assert len(crypttext) == len(data)
|
||||
|
||||
def encrypt(readkey):
|
||||
salt = os.urandom(16)
|
||||
key = hashutil.ssk_readkey_data_hash(salt, readkey)
|
||||
encryptor = aes.create_encryptor(key)
|
||||
crypttext = aes.encrypt_data(encryptor, data)
|
||||
assert len(crypttext) == len(data)
|
||||
return salt, crypttext
|
||||
|
||||
salt, crypttext = await defer_to_thread(encrypt, self.readkey)
|
||||
|
||||
now = time.time()
|
||||
self._status.accumulate_encrypt_time(now - started)
|
||||
@ -753,16 +760,14 @@ class Publish(object):
|
||||
piece = piece + b"\x00"*(piece_size - len(piece)) # padding
|
||||
crypttext_pieces[i] = piece
|
||||
assert len(piece) == piece_size
|
||||
d = fec.encode(crypttext_pieces)
|
||||
def _done_encoding(res):
|
||||
elapsed = time.time() - started
|
||||
self._status.accumulate_encode_time(elapsed)
|
||||
return (res, salt)
|
||||
d.addCallback(_done_encoding)
|
||||
return d
|
||||
|
||||
res = await fec.encode(crypttext_pieces)
|
||||
elapsed = time.time() - started
|
||||
self._status.accumulate_encode_time(elapsed)
|
||||
return (res, salt)
|
||||
|
||||
def _push_segment(self, encoded_and_salt, segnum):
|
||||
@async_to_deferred
|
||||
async def _push_segment(self, encoded_and_salt, segnum):
|
||||
"""
|
||||
I push (data, salt) as segment number segnum.
|
||||
"""
|
||||
@ -776,7 +781,7 @@ class Publish(object):
|
||||
hashed = salt + sharedata
|
||||
else:
|
||||
hashed = sharedata
|
||||
block_hash = hashutil.block_hash(hashed)
|
||||
block_hash = await defer_to_thread(hashutil.block_hash, hashed)
|
||||
self.blockhashes[shareid][segnum] = block_hash
|
||||
# find the writer for this share
|
||||
writers = self.writers[shareid]
|
||||
|
@ -20,6 +20,7 @@ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
|
||||
from allmydata.util.assertutil import _assert, precondition
|
||||
from allmydata.util import hashutil, log, mathutil, deferredutil
|
||||
from allmydata.util.dictutil import DictOfSets
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata import hashtree, codec
|
||||
from allmydata.storage.server import si_b2a
|
||||
|
||||
@ -734,7 +735,8 @@ class Retrieve(object):
|
||||
return None
|
||||
|
||||
|
||||
def _validate_block(self, results, segnum, reader, server, started):
|
||||
@deferredutil.async_to_deferred
|
||||
async def _validate_block(self, results, segnum, reader, server, started):
|
||||
"""
|
||||
I validate a block from one share on a remote server.
|
||||
"""
|
||||
@ -767,9 +769,9 @@ class Retrieve(object):
|
||||
"block hash tree failure: %s" % e)
|
||||
|
||||
if self._version == MDMF_VERSION:
|
||||
blockhash = hashutil.block_hash(salt + block)
|
||||
blockhash = await defer_to_thread(hashutil.block_hash, salt + block)
|
||||
else:
|
||||
blockhash = hashutil.block_hash(block)
|
||||
blockhash = await defer_to_thread(hashutil.block_hash, block)
|
||||
# If this works without an error, then validation is
|
||||
# successful.
|
||||
try:
|
||||
@ -893,8 +895,8 @@ class Retrieve(object):
|
||||
d.addCallback(_process)
|
||||
return d
|
||||
|
||||
|
||||
def _decrypt_segment(self, segment_and_salt):
|
||||
@deferredutil.async_to_deferred
|
||||
async def _decrypt_segment(self, segment_and_salt):
|
||||
"""
|
||||
I take a single segment and its salt, and decrypt it. I return
|
||||
the plaintext of the segment that is in my argument.
|
||||
@ -903,9 +905,14 @@ class Retrieve(object):
|
||||
self._set_current_status("decrypting")
|
||||
self.log("decrypting segment %d" % self._current_segment)
|
||||
started = time.time()
|
||||
key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
|
||||
decryptor = aes.create_decryptor(key)
|
||||
plaintext = aes.decrypt_data(decryptor, segment)
|
||||
readkey = self._node.get_readkey()
|
||||
|
||||
def decrypt():
|
||||
key = hashutil.ssk_readkey_data_hash(salt, readkey)
|
||||
decryptor = aes.create_decryptor(key)
|
||||
return aes.decrypt_data(decryptor, segment)
|
||||
|
||||
plaintext = await defer_to_thread(decrypt)
|
||||
self._status.accumulate_decrypt_time(time.time() - started)
|
||||
return plaintext
|
||||
|
||||
|
@ -64,6 +64,7 @@ from .common import si_b2a, si_to_human_readable
|
||||
from ..util.hashutil import timing_safe_compare
|
||||
from ..util.deferredutil import async_to_deferred
|
||||
from ..util.tor_provider import _Provider as TorProvider
|
||||
from ..util.cputhreadpool import defer_to_thread
|
||||
|
||||
try:
|
||||
from txtorcon import Tor # type: ignore
|
||||
@ -473,7 +474,8 @@ class StorageClient(object):
|
||||
into corresponding HTTP headers.
|
||||
|
||||
If ``message_to_serialize`` is set, it will be serialized (by default
|
||||
with CBOR) and set as the request body.
|
||||
with CBOR) and set as the request body. It should not be mutated
|
||||
during execution of this function!
|
||||
|
||||
Default timeout is 60 seconds.
|
||||
"""
|
||||
@ -539,7 +541,7 @@ class StorageClient(object):
|
||||
"Can't use both `message_to_serialize` and `data` "
|
||||
"as keyword arguments at the same time"
|
||||
)
|
||||
kwargs["data"] = dumps(message_to_serialize)
|
||||
kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
|
||||
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
|
||||
|
||||
response = await self._treq.request(
|
||||
@ -557,8 +559,12 @@ class StorageClient(object):
|
||||
if content_type == CBOR_MIME_TYPE:
|
||||
f = await limited_content(response, self._clock)
|
||||
data = f.read()
|
||||
schema.validate_cbor(data)
|
||||
return loads(data)
|
||||
|
||||
def validate_and_decode():
|
||||
schema.validate_cbor(data)
|
||||
return loads(data)
|
||||
|
||||
return await defer_to_thread(validate_and_decode)
|
||||
else:
|
||||
raise ClientException(
|
||||
-1,
|
||||
@ -1232,7 +1238,8 @@ class StorageClientMutables:
|
||||
return cast(
|
||||
Set[int],
|
||||
await self._client.decode_cbor(
|
||||
response, _SCHEMAS["mutable_list_shares"]
|
||||
response,
|
||||
_SCHEMAS["mutable_list_shares"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
@ -638,10 +638,7 @@ async def read_encoded(
|
||||
|
||||
# Pycddl will release the GIL when validating larger documents, so
|
||||
# let's take advantage of multiple CPUs:
|
||||
if size > 10_000:
|
||||
await defer_to_thread(reactor, schema.validate_cbor, message)
|
||||
else:
|
||||
schema.validate_cbor(message)
|
||||
await defer_to_thread(schema.validate_cbor, message)
|
||||
|
||||
# The CBOR parser will allocate more memory, but at least we can feed
|
||||
# it the file-like object, so that if it's large it won't be make two
|
||||
|
40
src/allmydata/test/blocking.py
Normal file
40
src/allmydata/test/blocking.py
Normal file
@ -0,0 +1,40 @@
|
||||
import sys
|
||||
import traceback
|
||||
import signal
|
||||
import threading
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
|
||||
def print_stacks():
|
||||
print("Uh oh, something is blocking the event loop!")
|
||||
current_thread = threading.get_ident()
|
||||
for thread_id, frame in sys._current_frames().items():
|
||||
if thread_id == current_thread:
|
||||
traceback.print_stack(frame, limit=10)
|
||||
break
|
||||
|
||||
|
||||
def catch_blocking_in_event_loop(test=None):
|
||||
"""
|
||||
Print tracebacks if the event loop is blocked for more than a short amount
|
||||
of time.
|
||||
"""
|
||||
signal.signal(signal.SIGALRM, lambda *args: print_stacks())
|
||||
|
||||
current_scheduled = [None]
|
||||
|
||||
def cancel_and_rerun():
|
||||
signal.setitimer(signal.ITIMER_REAL, 0)
|
||||
signal.setitimer(signal.ITIMER_REAL, 0.015)
|
||||
current_scheduled[0] = reactor.callLater(0.01, cancel_and_rerun)
|
||||
|
||||
cancel_and_rerun()
|
||||
|
||||
def cleanup():
|
||||
signal.signal(signal.SIGALRM, signal.SIG_DFL)
|
||||
signal.setitimer(signal.ITIMER_REAL, 0)
|
||||
current_scheduled[0].cancel()
|
||||
|
||||
if test is not None:
|
||||
test.addCleanup(cleanup)
|
@ -424,7 +424,7 @@ class Check(GridTestMixin, CLITestMixin, unittest.TestCase):
|
||||
def _stash_uri(n):
|
||||
self.uriList.append(n.get_uri())
|
||||
d.addCallback(_stash_uri)
|
||||
d = c0.create_dirnode()
|
||||
d.addCallback(lambda _: c0.create_dirnode())
|
||||
d.addCallback(_stash_uri)
|
||||
|
||||
d.addCallback(lambda ign: self.do_cli("check", self.uriList[0], self.uriList[1]))
|
||||
|
@ -685,6 +685,10 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
|
||||
REDUCE_HTTP_CLIENT_TIMEOUT : bool = True
|
||||
|
||||
def setUp(self):
|
||||
if os.getenv("TAHOE_DEBUG_BLOCKING") == "1":
|
||||
from .blocking import catch_blocking_in_event_loop
|
||||
catch_blocking_in_event_loop(self)
|
||||
|
||||
self._http_client_pools = []
|
||||
http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool)
|
||||
self.addCleanup(http_client.StorageClientFactory.stop_test_mode)
|
||||
|
@ -252,7 +252,7 @@ def create_no_network_client(basedir):
|
||||
i2p_provider=None,
|
||||
tor_provider=None,
|
||||
introducer_clients=[],
|
||||
storage_farm_broker=storage_broker,
|
||||
storage_farm_broker=storage_broker
|
||||
)
|
||||
# this is a (pre-existing) reference-cycle and also a bad idea, see:
|
||||
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949
|
||||
|
@ -43,6 +43,7 @@ from testtools.matchers import Equals
|
||||
from zope.interface import implementer
|
||||
|
||||
from ..util.deferredutil import async_to_deferred
|
||||
from ..util.cputhreadpool import disable_thread_pool_for_test
|
||||
from .common import SyncTestCase
|
||||
from ..storage.http_common import (
|
||||
get_content_type,
|
||||
@ -345,6 +346,7 @@ class CustomHTTPServerTests(SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(CustomHTTPServerTests, self).setUp()
|
||||
disable_thread_pool_for_test(self)
|
||||
StorageClientFactory.start_test_mode(
|
||||
lambda pool: self.addCleanup(pool.closeCachedConnections)
|
||||
)
|
||||
@ -701,6 +703,7 @@ class GenericHTTPAPITests(SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(GenericHTTPAPITests, self).setUp()
|
||||
disable_thread_pool_for_test(self)
|
||||
self.http = self.useFixture(HttpTestFixture())
|
||||
|
||||
def test_missing_authentication(self) -> None:
|
||||
@ -808,6 +811,7 @@ class ImmutableHTTPAPITests(SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ImmutableHTTPAPITests, self).setUp()
|
||||
disable_thread_pool_for_test(self)
|
||||
self.http = self.useFixture(HttpTestFixture())
|
||||
self.imm_client = StorageClientImmutables(self.http.client)
|
||||
self.general_client = StorageClientGeneral(self.http.client)
|
||||
@ -1317,6 +1321,7 @@ class MutableHTTPAPIsTests(SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MutableHTTPAPIsTests, self).setUp()
|
||||
disable_thread_pool_for_test(self)
|
||||
self.http = self.useFixture(HttpTestFixture())
|
||||
self.mut_client = StorageClientMutables(self.http.client)
|
||||
|
||||
@ -1734,6 +1739,7 @@ class ImmutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ImmutableSharedTests, self).setUp()
|
||||
disable_thread_pool_for_test(self)
|
||||
self.http = self.useFixture(HttpTestFixture())
|
||||
self.client = self.clientFactory(self.http.client)
|
||||
self.general_client = StorageClientGeneral(self.http.client)
|
||||
@ -1788,6 +1794,7 @@ class MutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MutableSharedTests, self).setUp()
|
||||
disable_thread_pool_for_test(self)
|
||||
self.http = self.useFixture(HttpTestFixture())
|
||||
self.client = self.clientFactory(self.http.client)
|
||||
self.general_client = StorageClientGeneral(self.http.client)
|
||||
|
@ -18,7 +18,6 @@ import json
|
||||
from threading import current_thread
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor
|
||||
from foolscap.api import Violation, RemoteException
|
||||
|
||||
from allmydata.util import idlib, mathutil
|
||||
@ -28,7 +27,7 @@ from allmydata.util import pollmixin
|
||||
from allmydata.util import yamlutil
|
||||
from allmydata.util import rrefutil
|
||||
from allmydata.util.fileutil import EncryptedTemporaryFile
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata.util.cputhreadpool import defer_to_thread, disable_thread_pool_for_test
|
||||
from allmydata.test.common_util import ReallyEqualMixin
|
||||
from .no_network import fireNow, LocalWrapper
|
||||
|
||||
@ -613,20 +612,25 @@ class CPUThreadPool(unittest.TestCase):
|
||||
return current_thread(), args, kwargs
|
||||
|
||||
this_thread = current_thread().ident
|
||||
result = defer_to_thread(reactor, f, 1, 3, key=4, value=5)
|
||||
|
||||
# Callbacks run in the correct thread:
|
||||
callback_thread_ident = []
|
||||
def passthrough(result):
|
||||
callback_thread_ident.append(current_thread().ident)
|
||||
return result
|
||||
|
||||
result.addCallback(passthrough)
|
||||
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
|
||||
|
||||
# The task ran in a different thread:
|
||||
thread, args, kwargs = await result
|
||||
self.assertEqual(callback_thread_ident[0], this_thread)
|
||||
self.assertNotEqual(thread.ident, this_thread)
|
||||
self.assertEqual(args, (1, 3))
|
||||
self.assertEqual(kwargs, {"key": 4, "value": 5})
|
||||
|
||||
async def test_when_disabled_runs_in_same_thread(self):
|
||||
"""
|
||||
If the CPU thread pool is disabled, the given function runs in the
|
||||
current thread.
|
||||
"""
|
||||
disable_thread_pool_for_test(self)
|
||||
def f(*args, **kwargs):
|
||||
return current_thread().ident, args, kwargs
|
||||
|
||||
this_thread = current_thread().ident
|
||||
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
|
||||
|
||||
self.assertEqual(thread, this_thread)
|
||||
self.assertEqual(args, (1, 3))
|
||||
self.assertEqual(kwargs, {"key": 4, "value": 5})
|
||||
|
@ -15,16 +15,15 @@ scheduler affinity or cgroups, but that's not the end of the world.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import TypeVar, Callable, cast
|
||||
from typing import TypeVar, Callable
|
||||
from functools import partial
|
||||
import threading
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest import TestCase
|
||||
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.threads import deferToThreadPool
|
||||
from twisted.internet.interfaces import IReactorFromThreads
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
_CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU")
|
||||
if hasattr(threading, "_register_atexit"):
|
||||
@ -46,14 +45,43 @@ _CPU_THREAD_POOL.start()
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Is running in a thread pool disabled? Should only be true in synchronous unit
|
||||
# tests.
|
||||
_DISABLED = False
|
||||
|
||||
|
||||
async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Run the function in a thread, return the result.
|
||||
|
||||
However, if ``disable_thread_pool_for_test()`` was called the function will
|
||||
be called synchronously inside the current thread.
|
||||
|
||||
To reduce chances of synchronous tests being misleading as a result, this
|
||||
is an async function on presumption that will encourage immediate ``await``ing.
|
||||
"""
|
||||
if _DISABLED:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
def defer_to_thread(
|
||||
reactor: IReactorFromThreads, f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> Deferred[R]:
|
||||
"""Run the function in a thread, return the result as a ``Deferred``."""
|
||||
# deferToThreadPool has no type annotations...
|
||||
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
|
||||
return cast(Deferred[R], result)
|
||||
result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = ["defer_to_thread"]
|
||||
def disable_thread_pool_for_test(test: TestCase) -> None:
|
||||
"""
|
||||
For the duration of the test, calls to ``defer_to_thread()`` will actually
|
||||
run synchronously, which is useful for synchronous unit tests.
|
||||
"""
|
||||
global _DISABLED
|
||||
|
||||
def restore():
|
||||
global _DISABLED
|
||||
_DISABLED = False
|
||||
|
||||
test.addCleanup(restore)
|
||||
|
||||
_DISABLED = True
|
||||
|
||||
|
||||
__all__ = ["defer_to_thread", "disable_thread_pool_for_test"]
|
||||
|
Loading…
Reference in New Issue
Block a user