Mirror of https://github.com/tahoe-lafs/tahoe-lafs.git
Merge pull request #1342 from tahoe-lafs/4068-reduce-cpu-in-eventloop-thread

Reduce blocking operations in eventloop thread

Merge commit: 4fbf31b00c

newsfragments/4068.feature (new file, 1 line):
@@ -0,0 +1 @@
+Some operations now run in threads, improving the responsiveness of Tahoe nodes.
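The recurring pattern in this PR: CPU-heavy work (hashing, AES encryption, zfec erasure coding, CBOR encoding/decoding) moves off the Twisted reactor thread onto a dedicated thread pool via the new defer_to_thread() API, while async_to_deferred keeps the external Deferred-based interface intact. A minimal sketch of the pattern, assuming a stand-in workload (expensive_hash is illustrative, not from the diff):

    import hashlib

    from allmydata.util.cputhreadpool import defer_to_thread
    from allmydata.util.deferredutil import async_to_deferred

    def expensive_hash(data: bytes) -> bytes:
        # Stand-in for the real CPU-bound work (hashing, AES, zfec).
        return hashlib.sha256(data).digest()

    @async_to_deferred
    async def digest(data: bytes) -> bytes:
        # Runs in the "TahoeCPU" pool; the reactor thread stays responsive.
        # Callers still receive a plain Deferred, thanks to async_to_deferred.
        return await defer_to_thread(expensive_hash, data)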
@@ -47,6 +47,8 @@ from allmydata.util.abbreviate import parse_abbreviated_size
 from allmydata.util.time_format import parse_duration, parse_date
 from allmydata.util.i2p_provider import create as create_i2p_provider
 from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider
+from allmydata.util.cputhreadpool import defer_to_thread
+from allmydata.util.deferredutil import async_to_deferred
 from allmydata.stats import StatsProvider
 from allmydata.history import History
 from allmydata.interfaces import (
@@ -170,12 +172,18 @@ class KeyGenerator(object):
     """I create RSA keys for mutable files. Each call to generate() returns a
     single keypair."""

-    def generate(self):
-        """I return a Deferred that fires with a (verifyingkey, signingkey)
-        pair. The returned key will be 2048 bit"""
+    @async_to_deferred
+    async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]:
+        """
+        I return a Deferred that fires with a (verifyingkey, signingkey)
+        pair. The returned key will be 2048 bit.
+        """
         keysize = 2048
-        signer, verifier = rsa.create_signing_keypair(keysize)
-        return defer.succeed( (verifier, signer) )
+        private, public = await defer_to_thread(
+            rsa.create_signing_keypair, keysize
+        )
+        return public, private


 class Terminator(service.Service):
     def __init__(self):
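Because async_to_deferred wraps the coroutine back into a Deferred, existing callers of generate() need no changes. A sketch of an unchanged call site (keygen stands for an existing KeyGenerator instance):

    d = keygen.generate()  # still returns a Deferred, despite the async def

    def _got_keys(keypair):
        verifying_key, signing_key = keypair
        print(verifying_key, signing_key)

    d.addCallback(_got_keys)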
@@ -13,9 +13,10 @@ if PY2:
     from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min  # noqa: F401

 from zope.interface import implementer
-from twisted.internet import defer
 from allmydata.util import mathutil
 from allmydata.util.assertutil import precondition
+from allmydata.util.cputhreadpool import defer_to_thread
+from allmydata.util.deferredutil import async_to_deferred
 from allmydata.interfaces import ICodecEncoder, ICodecDecoder
 import zfec

@@ -45,7 +46,8 @@ class CRSEncoder(object):
     def get_block_size(self):
         return self.share_size

-    def encode(self, inshares, desired_share_ids=None):
+    @async_to_deferred
+    async def encode(self, inshares, desired_share_ids=None):
         precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares)

         if desired_share_ids is None:
@@ -53,9 +55,8 @@ class CRSEncoder(object):

         for inshare in inshares:
             assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
-        shares = self.encoder.encode(inshares, desired_share_ids)
-
-        return defer.succeed((shares, desired_share_ids))
+        shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
+        return (shares, desired_share_ids)

     def encode_proposal(self, data, desired_share_ids=None):
         raise NotImplementedError()
@@ -77,14 +78,17 @@ class CRSDecoder(object):
     def get_needed_shares(self):
         return self.required_shares

-    def decode(self, some_shares, their_shareids):
+    @async_to_deferred
+    async def decode(self, some_shares, their_shareids):
         precondition(len(some_shares) == len(their_shareids),
                      len(some_shares), len(their_shareids))
         precondition(len(some_shares) == self.required_shares,
                      len(some_shares), self.required_shares)
-        data = self.decoder.decode(some_shares,
-                                   [int(s) for s in their_shareids])
-        return defer.succeed(data)
+        return await defer_to_thread(
+            self.decoder.decode,
+            some_shares,
+            [int(s) for s in their_shareids]
+        )

 def parse_params(serializedparams):
     pieces = serializedparams.split(b"-")
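The encoder and decoder being offloaded here are zfec objects, whose encode/decode calls are the CPU-bound step. A standalone sketch of that underlying work, assuming illustrative 3-of-5 parameters:

    import zfec

    # required_shares=3, max_shares=5: any 3 of the 5 shares recover the data.
    encoder = zfec.Encoder(3, 5)
    pieces = [b"aaaa", b"bbbb", b"cccc"]  # equal-sized input pieces
    shares = encoder.encode(pieces)       # the CPU-bound call now run in a thread

    decoder = zfec.Decoder(3, 5)
    # Shares 0-2 are the primary shares, so decoding them returns the input.
    recovered = decoder.decode(shares[:3], [0, 1, 2])
    assert recovered == pieces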
@@ -23,6 +23,8 @@ from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
     IMutableUploadable
 from allmydata.util import base32, hashutil, mathutil, log
 from allmydata.util.dictutil import DictOfSets
+from allmydata.util.deferredutil import async_to_deferred
+from allmydata.util.cputhreadpool import defer_to_thread
 from allmydata import hashtree, codec
 from allmydata.storage.server import si_b2a
 from foolscap.api import eventually, fireEventually
@@ -706,7 +708,8 @@ class Publish(object):
             writer.put_salt(salt)


-    def _encode_segment(self, segnum):
+    @async_to_deferred
+    async def _encode_segment(self, segnum):
         """
         I encrypt and encode the segment segnum.
         """
@@ -726,13 +729,17 @@ class Publish(object):

         assert len(data) == segsize, len(data)

-        salt = os.urandom(16)
-
-        key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
         self._status.set_status("Encrypting")
-        encryptor = aes.create_encryptor(key)
-        crypttext = aes.encrypt_data(encryptor, data)
-        assert len(crypttext) == len(data)
+
+        def encrypt(readkey):
+            salt = os.urandom(16)
+            key = hashutil.ssk_readkey_data_hash(salt, readkey)
+            encryptor = aes.create_encryptor(key)
+            crypttext = aes.encrypt_data(encryptor, data)
+            assert len(crypttext) == len(data)
+            return salt, crypttext
+
+        salt, crypttext = await defer_to_thread(encrypt, self.readkey)

         now = time.time()
         self._status.accumulate_encrypt_time(now - started)
@@ -753,16 +760,14 @@ class Publish(object):
             piece = piece + b"\x00"*(piece_size - len(piece)) # padding
             crypttext_pieces[i] = piece
             assert len(piece) == piece_size
-        d = fec.encode(crypttext_pieces)
-        def _done_encoding(res):
-            elapsed = time.time() - started
-            self._status.accumulate_encode_time(elapsed)
-            return (res, salt)
-        d.addCallback(_done_encoding)
-        return d
+        res = await fec.encode(crypttext_pieces)
+        elapsed = time.time() - started
+        self._status.accumulate_encode_time(elapsed)
+        return (res, salt)

-    def _push_segment(self, encoded_and_salt, segnum):
+    @async_to_deferred
+    async def _push_segment(self, encoded_and_salt, segnum):
         """
         I push (data, salt) as segment number segnum.
         """
@@ -776,7 +781,7 @@ class Publish(object):
                 hashed = salt + sharedata
             else:
                 hashed = sharedata
-            block_hash = hashutil.block_hash(hashed)
+            block_hash = await defer_to_thread(hashutil.block_hash, hashed)
             self.blockhashes[shareid][segnum] = block_hash
             # find the writer for this share
             writers = self.writers[shareid]
@@ -20,6 +20,7 @@ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
 from allmydata.util.assertutil import _assert, precondition
 from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
+from allmydata.util.cputhreadpool import defer_to_thread
 from allmydata import hashtree, codec
 from allmydata.storage.server import si_b2a

@@ -734,7 +735,8 @@ class Retrieve(object):
             return None


-    def _validate_block(self, results, segnum, reader, server, started):
+    @deferredutil.async_to_deferred
+    async def _validate_block(self, results, segnum, reader, server, started):
         """
         I validate a block from one share on a remote server.
         """
@@ -767,9 +769,9 @@ class Retrieve(object):
                                    "block hash tree failure: %s" % e)

         if self._version == MDMF_VERSION:
-            blockhash = hashutil.block_hash(salt + block)
+            blockhash = await defer_to_thread(hashutil.block_hash, salt + block)
         else:
-            blockhash = hashutil.block_hash(block)
+            blockhash = await defer_to_thread(hashutil.block_hash, block)
         # If this works without an error, then validation is
         # successful.
         try:
@@ -893,8 +895,8 @@ class Retrieve(object):
         d.addCallback(_process)
         return d

-    def _decrypt_segment(self, segment_and_salt):
+    @deferredutil.async_to_deferred
+    async def _decrypt_segment(self, segment_and_salt):
         """
         I take a single segment and its salt, and decrypt it. I return
         the plaintext of the segment that is in my argument.
@@ -903,9 +905,14 @@ class Retrieve(object):
         self._set_current_status("decrypting")
         self.log("decrypting segment %d" % self._current_segment)
         started = time.time()
-        key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
-        decryptor = aes.create_decryptor(key)
-        plaintext = aes.decrypt_data(decryptor, segment)
+        readkey = self._node.get_readkey()
+
+        def decrypt():
+            key = hashutil.ssk_readkey_data_hash(salt, readkey)
+            decryptor = aes.create_decryptor(key)
+            return aes.decrypt_data(decryptor, segment)
+
+        plaintext = await defer_to_thread(decrypt)
         self._status.accumulate_decrypt_time(time.time() - started)
         return plaintext

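Note the discipline shared by the publish and retrieve changes: state held on self (such as the read key) is snapshotted on the reactor thread before the closure is handed to the pool, so the worker thread never touches shared mutable state. A condensed sketch of that shape (function and variable names are illustrative):

    from allmydata.crypto import aes
    from allmydata.util import hashutil
    from allmydata.util.cputhreadpool import defer_to_thread

    async def decrypt_segment(node, salt: bytes, segment: bytes) -> bytes:
        readkey = node.get_readkey()  # reactor thread: snapshot shared state

        def decrypt() -> bytes:
            # Worker thread: pure computation on captured locals only.
            key = hashutil.ssk_readkey_data_hash(salt, readkey)
            decryptor = aes.create_decryptor(key)
            return aes.decrypt_data(decryptor, segment)

        return await defer_to_thread(decrypt)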
@@ -64,6 +64,7 @@ from .common import si_b2a, si_to_human_readable
 from ..util.hashutil import timing_safe_compare
 from ..util.deferredutil import async_to_deferred
 from ..util.tor_provider import _Provider as TorProvider
+from ..util.cputhreadpool import defer_to_thread

 try:
     from txtorcon import Tor  # type: ignore
@@ -473,7 +474,8 @@ class StorageClient(object):
         into corresponding HTTP headers.

         If ``message_to_serialize`` is set, it will be serialized (by default
-        with CBOR) and set as the request body.
+        with CBOR) and set as the request body. It should not be mutated
+        during execution of this function!

         Default timeout is 60 seconds.
         """
@@ -539,7 +541,7 @@ class StorageClient(object):
                 "Can't use both `message_to_serialize` and `data` "
                 "as keyword arguments at the same time"
             )
-            kwargs["data"] = dumps(message_to_serialize)
+            kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
             headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)

         response = await self._treq.request(
@@ -557,8 +559,12 @@ class StorageClient(object):
         if content_type == CBOR_MIME_TYPE:
             f = await limited_content(response, self._clock)
             data = f.read()
-            schema.validate_cbor(data)
-            return loads(data)
+
+            def validate_and_decode():
+                schema.validate_cbor(data)
+                return loads(data)
+
+            return await defer_to_thread(validate_and_decode)
         else:
             raise ClientException(
                 -1,
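Serializing and deserializing on the pool is also why the docstring above now warns against mutating message_to_serialize: once the dict is handed off, dumps() may run concurrently with reactor-thread code. An illustrative misuse (the request helper shown is hypothetical):

    message = {"share-numbers": [1, 2, 3]}
    d = client.request_with_body(message_to_serialize=message)  # hypothetical helper
    message["share-numbers"].append(4)  # WRONG: races with dumps() in the pool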
@@ -1232,7 +1238,8 @@ class StorageClientMutables:
             return cast(
                 Set[int],
                 await self._client.decode_cbor(
-                    response, _SCHEMAS["mutable_list_shares"]
+                    response,
+                    _SCHEMAS["mutable_list_shares"],
                 ),
             )
         else:
@@ -638,10 +638,7 @@ async def read_encoded(

     # Pycddl will release the GIL when validating larger documents, so
     # let's take advantage of multiple CPUs:
-    if size > 10_000:
-        await defer_to_thread(reactor, schema.validate_cbor, message)
-    else:
-        schema.validate_cbor(message)
+    await defer_to_thread(schema.validate_cbor, message)

     # The CBOR parser will allocate more memory, but at least we can feed
     # it the file-like object, so that if it's large it won't make two
src/allmydata/test/blocking.py (new file, 40 lines):

@@ -0,0 +1,40 @@
+import sys
+import traceback
+import signal
+import threading
+
+from twisted.internet import reactor
+
+
+def print_stacks():
+    print("Uh oh, something is blocking the event loop!")
+    current_thread = threading.get_ident()
+    for thread_id, frame in sys._current_frames().items():
+        if thread_id == current_thread:
+            traceback.print_stack(frame, limit=10)
+            break
+
+
+def catch_blocking_in_event_loop(test=None):
+    """
+    Print tracebacks if the event loop is blocked for more than a short amount
+    of time.
+    """
+    signal.signal(signal.SIGALRM, lambda *args: print_stacks())
+
+    current_scheduled = [None]
+
+    def cancel_and_rerun():
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        signal.setitimer(signal.ITIMER_REAL, 0.015)
+        current_scheduled[0] = reactor.callLater(0.01, cancel_and_rerun)
+
+    cancel_and_rerun()
+
+    def cleanup():
+        signal.signal(signal.SIGALRM, signal.SIG_DFL)
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        current_scheduled[0].cancel()
+
+    if test is not None:
+        test.addCleanup(cleanup)
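The detector arms a 15 ms SIGALRM timer and re-arms it from the reactor every 10 ms; if a callback blocks the loop long enough for the alarm to fire, the handler prints the reactor thread's stack. In the test suite it is enabled via TAHOE_DEBUG_BLOCKING=1 (see the setUp hunk below); a sketch of standalone use, assuming a POSIX platform with SIGALRM:

    import time
    from twisted.internet import reactor
    from allmydata.test.blocking import catch_blocking_in_event_loop

    catch_blocking_in_event_loop()         # no TestCase given, so no auto-cleanup
    reactor.callLater(1, time.sleep, 0.5)  # deliberately block the event loop...
    reactor.callLater(2, reactor.stop)     # ...and expect a traceback to be printed
    reactor.run()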
@@ -424,7 +424,7 @@ class Check(GridTestMixin, CLITestMixin, unittest.TestCase):
         def _stash_uri(n):
             self.uriList.append(n.get_uri())
         d.addCallback(_stash_uri)
-        d = c0.create_dirnode()
+        d.addCallback(lambda _: c0.create_dirnode())
         d.addCallback(_stash_uri)

         d.addCallback(lambda ign: self.do_cli("check", self.uriList[0], self.uriList[1]))
@@ -685,6 +685,10 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
     REDUCE_HTTP_CLIENT_TIMEOUT : bool = True

     def setUp(self):
+        if os.getenv("TAHOE_DEBUG_BLOCKING") == "1":
+            from .blocking import catch_blocking_in_event_loop
+            catch_blocking_in_event_loop(self)
+
         self._http_client_pools = []
         http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool)
         self.addCleanup(http_client.StorageClientFactory.stop_test_mode)
@@ -252,7 +252,7 @@ def create_no_network_client(basedir):
         i2p_provider=None,
         tor_provider=None,
         introducer_clients=[],
-        storage_farm_broker=storage_broker,
+        storage_farm_broker=storage_broker
     )
     # this is a (pre-existing) reference-cycle and also a bad idea, see:
     # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949
@@ -43,6 +43,7 @@ from testtools.matchers import Equals
 from zope.interface import implementer

 from ..util.deferredutil import async_to_deferred
+from ..util.cputhreadpool import disable_thread_pool_for_test
 from .common import SyncTestCase
 from ..storage.http_common import (
     get_content_type,
@@ -345,6 +346,7 @@ class CustomHTTPServerTests(SyncTestCase):

     def setUp(self):
         super(CustomHTTPServerTests, self).setUp()
+        disable_thread_pool_for_test(self)
         StorageClientFactory.start_test_mode(
             lambda pool: self.addCleanup(pool.closeCachedConnections)
         )
@@ -701,6 +703,7 @@ class GenericHTTPAPITests(SyncTestCase):

     def setUp(self):
         super(GenericHTTPAPITests, self).setUp()
+        disable_thread_pool_for_test(self)
         self.http = self.useFixture(HttpTestFixture())

     def test_missing_authentication(self) -> None:
@@ -808,6 +811,7 @@ class ImmutableHTTPAPITests(SyncTestCase):

     def setUp(self):
         super(ImmutableHTTPAPITests, self).setUp()
+        disable_thread_pool_for_test(self)
         self.http = self.useFixture(HttpTestFixture())
         self.imm_client = StorageClientImmutables(self.http.client)
         self.general_client = StorageClientGeneral(self.http.client)
@@ -1317,6 +1321,7 @@ class MutableHTTPAPIsTests(SyncTestCase):

     def setUp(self):
         super(MutableHTTPAPIsTests, self).setUp()
+        disable_thread_pool_for_test(self)
         self.http = self.useFixture(HttpTestFixture())
         self.mut_client = StorageClientMutables(self.http.client)

@@ -1734,6 +1739,7 @@ class ImmutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):

     def setUp(self):
         super(ImmutableSharedTests, self).setUp()
+        disable_thread_pool_for_test(self)
         self.http = self.useFixture(HttpTestFixture())
         self.client = self.clientFactory(self.http.client)
         self.general_client = StorageClientGeneral(self.http.client)
@@ -1788,6 +1794,7 @@ class MutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):

     def setUp(self):
         super(MutableSharedTests, self).setUp()
+        disable_thread_pool_for_test(self)
         self.http = self.useFixture(HttpTestFixture())
         self.client = self.clientFactory(self.http.client)
         self.general_client = StorageClientGeneral(self.http.client)
@@ -18,7 +18,6 @@ import json
 from threading import current_thread

 from twisted.trial import unittest
-from twisted.internet import reactor
 from foolscap.api import Violation, RemoteException

 from allmydata.util import idlib, mathutil
@@ -28,7 +27,7 @@ from allmydata.util import pollmixin
 from allmydata.util import yamlutil
 from allmydata.util import rrefutil
 from allmydata.util.fileutil import EncryptedTemporaryFile
-from allmydata.util.cputhreadpool import defer_to_thread
+from allmydata.util.cputhreadpool import defer_to_thread, disable_thread_pool_for_test
 from allmydata.test.common_util import ReallyEqualMixin
 from .no_network import fireNow, LocalWrapper

@@ -613,20 +612,25 @@ class CPUThreadPool(unittest.TestCase):
             return current_thread(), args, kwargs

         this_thread = current_thread().ident
-        result = defer_to_thread(reactor, f, 1, 3, key=4, value=5)
-
-        # Callbacks run in the correct thread:
-        callback_thread_ident = []
-        def passthrough(result):
-            callback_thread_ident.append(current_thread().ident)
-            return result
-
-        result.addCallback(passthrough)
+        thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)

         # The task ran in a different thread:
-        thread, args, kwargs = await result
-        self.assertEqual(callback_thread_ident[0], this_thread)
         self.assertNotEqual(thread.ident, this_thread)
         self.assertEqual(args, (1, 3))
         self.assertEqual(kwargs, {"key": 4, "value": 5})

+    async def test_when_disabled_runs_in_same_thread(self):
+        """
+        If the CPU thread pool is disabled, the given function runs in the
+        current thread.
+        """
+        disable_thread_pool_for_test(self)
+        def f(*args, **kwargs):
+            return current_thread().ident, args, kwargs
+
+        this_thread = current_thread().ident
+        thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
+
+        self.assertEqual(thread, this_thread)
+        self.assertEqual(args, (1, 3))
+        self.assertEqual(kwargs, {"key": 4, "value": 5})
@@ -15,16 +15,15 @@ scheduler affinity or cgroups, but that's not the end of the world.
 """

 import os
-from typing import TypeVar, Callable, cast
+from typing import TypeVar, Callable
 from functools import partial
 import threading
 from typing_extensions import ParamSpec
+from unittest import TestCase

 from twisted.python.threadpool import ThreadPool
-from twisted.internet.defer import Deferred
 from twisted.internet.threads import deferToThreadPool
-from twisted.internet.interfaces import IReactorFromThreads
+from twisted.internet import reactor


 _CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU")
 if hasattr(threading, "_register_atexit"):
@@ -46,14 +45,43 @@ _CPU_THREAD_POOL.start()
 P = ParamSpec("P")
 R = TypeVar("R")

+# Is running in a thread pool disabled? Should only be true in synchronous unit
+# tests.
+_DISABLED = False

-def defer_to_thread(
-    reactor: IReactorFromThreads, f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
-) -> Deferred[R]:
-    """Run the function in a thread, return the result as a ``Deferred``."""
+
+async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
+    """
+    Run the function in a thread, return the result.
+
+    However, if ``disable_thread_pool_for_test()`` was called, the function
+    will instead be called synchronously inside the current thread.
+
+    To reduce the chances of synchronous tests being misleading as a result,
+    this is an async function, on the presumption that it will encourage
+    immediate ``await``ing.
+    """
+    if _DISABLED:
+        return f(*args, **kwargs)
+
     # deferToThreadPool has no type annotations...
-    result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
-    return cast(Deferred[R], result)
+    result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
+    return result


-__all__ = ["defer_to_thread"]
+def disable_thread_pool_for_test(test: TestCase) -> None:
+    """
+    For the duration of the test, calls to ``defer_to_thread()`` will actually
+    run synchronously, which is useful for synchronous unit tests.
+    """
+    global _DISABLED
+
+    def restore():
+        global _DISABLED
+        _DISABLED = False
+
+    test.addCleanup(restore)
+
+    _DISABLED = True
+
+
+__all__ = ["defer_to_thread", "disable_thread_pool_for_test"]
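Taken together, the new surface area is small: call await defer_to_thread(f, *args, **kwargs) from any coroutine (the explicit reactor argument is gone), and call disable_thread_pool_for_test(self) in a synchronous test's setUp. A hedged end-to-end sketch (fib is an illustrative stand-in workload):

    from allmydata.util.cputhreadpool import defer_to_thread
    from allmydata.util.deferredutil import async_to_deferred

    def fib(n: int) -> int:
        # Stand-in CPU-bound work.
        return n if n < 2 else fib(n - 1) + fib(n - 2)

    @async_to_deferred
    async def compute_fib(n: int) -> int:
        # Offloads fib() to the shared "TahoeCPU" pool; callers get a Deferred.
        return await defer_to_thread(fib, n)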