Merge pull request #1342 from tahoe-lafs/4068-reduce-cpu-in-eventloop-thread

Reduce blocking operations in eventloop thread
Itamar Turner-Trauring 2023-10-20 10:48:27 -04:00 committed by GitHub
commit 4fbf31b00c
14 changed files with 185 additions and 73 deletions

View File

@ -0,0 +1 @@
Some operations now run in threads, improving the responsiveness of Tahoe nodes.
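
The theme of the whole change, sketched below: CPU-bound work (RSA key generation, zfec erasure coding, AES, hashing, CBOR handling) moves off the Twisted reactor thread into a thread pool, so the event loop stays free to service network traffic. This is an illustrative stand-alone sketch using only public Twisted APIs; the PR's real helper lives in allmydata.util.cputhreadpool, and expensive_hash is a hypothetical stand-in for the real workloads.

import hashlib
import os

from twisted.internet import reactor
from twisted.internet.threads import deferToThreadPool
from twisted.python.threadpool import ThreadPool

# A small pool sized to the CPU count, mirroring allmydata.util.cputhreadpool.
pool = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="demo")
pool.start()
reactor.addSystemEventTrigger("during", "shutdown", pool.stop)

def expensive_hash(data: bytes) -> bytes:
    # Hypothetical stand-in for the real CPU-bound work (RSA, zfec, AES, CBOR).
    return hashlib.sha256(data).digest()

def main():
    # The reactor thread stays free to serve other events while this runs.
    d = deferToThreadPool(reactor, pool, expensive_hash, b"x" * 10_000_000)
    d.addCallback(lambda digest: print(digest.hex()))
    d.addBoth(lambda _: reactor.stop())

reactor.callWhenRunning(main)
reactor.run()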

View File

@ -47,6 +47,8 @@ from allmydata.util.abbreviate import parse_abbreviated_size
from allmydata.util.time_format import parse_duration, parse_date
from allmydata.util.i2p_provider import create as create_i2p_provider
from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.deferredutil import async_to_deferred
from allmydata.stats import StatsProvider
from allmydata.history import History
from allmydata.interfaces import (
@ -170,12 +172,18 @@ class KeyGenerator(object):
"""I create RSA keys for mutable files. Each call to generate() returns a
single keypair."""
def generate(self):
"""I return a Deferred that fires with a (verifyingkey, signingkey)
pair. The returned key will be 2048 bit"""
@async_to_deferred
async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]:
"""
I return a Deferred that fires with a (verifyingkey, signingkey)
pair. The returned key will be 2048 bit.
"""
keysize = 2048
signer, verifier = rsa.create_signing_keypair(keysize)
return defer.succeed( (verifier, signer) )
private, public = await defer_to_thread(
rsa.create_signing_keypair, keysize
)
return public, private
class Terminator(service.Service):
def __init__(self):
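
The @async_to_deferred decorator used above lets the method body be written with async/await while callers keep receiving a Deferred, so the public API is unchanged. Below is a minimal sketch of how such a decorator can be built on Twisted's ensureDeferred; the real one is in allmydata.util.deferredutil and may differ in details such as type annotations.

from functools import wraps

from twisted.internet.defer import Deferred, ensureDeferred

def async_to_deferred(f):
    """Wrap an async function so callers receive a Deferred (sketch)."""
    @wraps(f)
    def wrapper(*args, **kwargs) -> Deferred:
        # ensureDeferred drives the coroutine on the reactor and exposes
        # its eventual result as an ordinary Deferred.
        return ensureDeferred(f(*args, **kwargs))
    return wrapper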

View File

@ -13,9 +13,10 @@ if PY2:
from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from zope.interface import implementer
from twisted.internet import defer
from allmydata.util import mathutil
from allmydata.util.assertutil import precondition
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.deferredutil import async_to_deferred
from allmydata.interfaces import ICodecEncoder, ICodecDecoder
import zfec
@ -45,7 +46,8 @@ class CRSEncoder(object):
def get_block_size(self):
return self.share_size
def encode(self, inshares, desired_share_ids=None):
@async_to_deferred
async def encode(self, inshares, desired_share_ids=None):
precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares)
if desired_share_ids is None:
@ -53,9 +55,8 @@ class CRSEncoder(object):
for inshare in inshares:
assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
shares = self.encoder.encode(inshares, desired_share_ids)
return defer.succeed((shares, desired_share_ids))
shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
return (shares, desired_share_ids)
def encode_proposal(self, data, desired_share_ids=None):
raise NotImplementedError()
@ -77,14 +78,17 @@ class CRSDecoder(object):
def get_needed_shares(self):
return self.required_shares
def decode(self, some_shares, their_shareids):
@async_to_deferred
async def decode(self, some_shares, their_shareids):
precondition(len(some_shares) == len(their_shareids),
len(some_shares), len(their_shareids))
precondition(len(some_shares) == self.required_shares,
len(some_shares), self.required_shares)
data = self.decoder.decode(some_shares,
[int(s) for s in their_shareids])
return defer.succeed(data)
return await defer_to_thread(
self.decoder.decode,
some_shares,
[int(s) for s in their_shareids]
)
def parse_params(serializedparams):
pieces = serializedparams.split(b"-")
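
encode() and decode() keep their Deferred-returning contract but now push the zfec calls into the pool, since zfec holds the CPU for the duration of the call. A hedged stand-alone sketch of deferring zfec encoding to a thread; it assumes the zfec package, uses Twisted's default reactor pool rather than the PR's dedicated one, and the 3-of-10 parameters are illustrative.

import zfec
from twisted.internet import reactor
from twisted.internet.threads import deferToThread

def encode_blocking(pieces, k=3, m=10):
    # zfec blocks for the whole call, so run it in a worker thread.
    encoder = zfec.Encoder(k, m)
    return encoder.encode(pieces, list(range(m)))

def main():
    pieces = [b"a" * 1024, b"b" * 1024, b"c" * 1024]  # k equal-sized pieces
    d = deferToThread(encode_blocking, pieces)
    d.addCallback(lambda shares: print("produced", len(shares), "shares"))
    d.addBoth(lambda _: reactor.stop())

reactor.callWhenRunning(main)
reactor.run()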

View File

@ -23,6 +23,8 @@ from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
IMutableUploadable
from allmydata.util import base32, hashutil, mathutil, log
from allmydata.util.dictutil import DictOfSets
from allmydata.util.deferredutil import async_to_deferred
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a
from foolscap.api import eventually, fireEventually
@ -706,7 +708,8 @@ class Publish(object):
writer.put_salt(salt)
def _encode_segment(self, segnum):
@async_to_deferred
async def _encode_segment(self, segnum):
"""
I encrypt and encode the segment segnum.
"""
@ -726,13 +729,17 @@ class Publish(object):
assert len(data) == segsize, len(data)
salt = os.urandom(16)
key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
self._status.set_status("Encrypting")
encryptor = aes.create_encryptor(key)
crypttext = aes.encrypt_data(encryptor, data)
assert len(crypttext) == len(data)
def encrypt(readkey):
salt = os.urandom(16)
key = hashutil.ssk_readkey_data_hash(salt, readkey)
encryptor = aes.create_encryptor(key)
crypttext = aes.encrypt_data(encryptor, data)
assert len(crypttext) == len(data)
return salt, crypttext
salt, crypttext = await defer_to_thread(encrypt, self.readkey)
now = time.time()
self._status.accumulate_encrypt_time(now - started)
@ -753,16 +760,14 @@ class Publish(object):
piece = piece + b"\x00"*(piece_size - len(piece)) # padding
crypttext_pieces[i] = piece
assert len(piece) == piece_size
d = fec.encode(crypttext_pieces)
def _done_encoding(res):
elapsed = time.time() - started
self._status.accumulate_encode_time(elapsed)
return (res, salt)
d.addCallback(_done_encoding)
return d
res = await fec.encode(crypttext_pieces)
elapsed = time.time() - started
self._status.accumulate_encode_time(elapsed)
return (res, salt)
def _push_segment(self, encoded_and_salt, segnum):
@async_to_deferred
async def _push_segment(self, encoded_and_salt, segnum):
"""
I push (data, salt) as segment number segnum.
"""
@ -776,7 +781,7 @@ class Publish(object):
hashed = salt + sharedata
else:
hashed = sharedata
block_hash = hashutil.block_hash(hashed)
block_hash = await defer_to_thread(hashutil.block_hash, hashed)
self.blockhashes[shareid][segnum] = block_hash
# find the writer for this share
writers = self.writers[shareid]
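
_encode_segment bundles salt generation, key derivation, and AES encryption into one closure and ships the whole closure to the pool, so only a single thread hop is paid per segment. A sketch of that closure-into-thread pattern, assuming the cryptography package; the key derivation and CTR nonce here are illustrative stand-ins, not Tahoe's actual SSK scheme.

import hashlib
import os

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from twisted.internet import reactor
from twisted.internet.threads import deferToThread

def encrypt_segment(readkey: bytes, data: bytes):
    # Everything CPU-bound happens inside this one closure, in the worker
    # thread: salt generation, key derivation, and the AES pass itself.
    salt = os.urandom(16)
    key = hashlib.sha256(salt + readkey).digest()  # illustrative derivation
    encryptor = Cipher(algorithms.AES(key), modes.CTR(b"\0" * 16)).encryptor()
    crypttext = encryptor.update(data) + encryptor.finalize()
    assert len(crypttext) == len(data)  # CTR mode preserves length
    return salt, crypttext

def main():
    d = deferToThread(encrypt_segment, os.urandom(32), b"segment " * 100_000)
    d.addCallback(lambda res: print("salt:", res[0].hex()))
    d.addBoth(lambda _: reactor.stop())

reactor.callWhenRunning(main)
reactor.run()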

View File

@ -20,6 +20,7 @@ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
from allmydata.util.assertutil import _assert, precondition
from allmydata.util import hashutil, log, mathutil, deferredutil
from allmydata.util.dictutil import DictOfSets
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a
@ -734,7 +735,8 @@ class Retrieve(object):
return None
def _validate_block(self, results, segnum, reader, server, started):
@deferredutil.async_to_deferred
async def _validate_block(self, results, segnum, reader, server, started):
"""
I validate a block from one share on a remote server.
"""
@ -767,9 +769,9 @@ class Retrieve(object):
"block hash tree failure: %s" % e)
if self._version == MDMF_VERSION:
blockhash = hashutil.block_hash(salt + block)
blockhash = await defer_to_thread(hashutil.block_hash, salt + block)
else:
blockhash = hashutil.block_hash(block)
blockhash = await defer_to_thread(hashutil.block_hash, block)
# If this works without an error, then validation is
# successful.
try:
@ -893,8 +895,8 @@ class Retrieve(object):
d.addCallback(_process)
return d
def _decrypt_segment(self, segment_and_salt):
@deferredutil.async_to_deferred
async def _decrypt_segment(self, segment_and_salt):
"""
I take a single segment and its salt, and decrypt it. I return
the plaintext of the segment that is in my argument.
@ -903,9 +905,14 @@ class Retrieve(object):
self._set_current_status("decrypting")
self.log("decrypting segment %d" % self._current_segment)
started = time.time()
key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
decryptor = aes.create_decryptor(key)
plaintext = aes.decrypt_data(decryptor, segment)
readkey = self._node.get_readkey()
def decrypt():
key = hashutil.ssk_readkey_data_hash(salt, readkey)
decryptor = aes.create_decryptor(key)
return aes.decrypt_data(decryptor, segment)
plaintext = await defer_to_thread(decrypt)
self._status.accumulate_decrypt_time(time.time() - started)
return plaintext
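
Note the detail in _decrypt_segment above: self._node.get_readkey() is called on the reactor thread, and only the resulting value is captured by the closure, so the worker thread never touches shared mutable state. A minimal sketch of that snapshot-then-compute pattern; the class and crypto details are illustrative, not Tahoe's API.

import hashlib

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from twisted.internet.threads import deferToThread

class SegmentDecrypter:
    """Illustrative stand-in for Retrieve's decryption step."""

    def __init__(self, readkey: bytes):
        self._readkey = readkey

    def decrypt_segment(self, salt: bytes, segment: bytes):
        readkey = self._readkey  # snapshot shared state on the reactor thread

        def decrypt():
            # A pure function of its captured locals; safe off-reactor.
            key = hashlib.sha256(salt + readkey).digest()
            decryptor = Cipher(
                algorithms.AES(key), modes.CTR(b"\0" * 16)
            ).decryptor()
            return decryptor.update(segment) + decryptor.finalize()

        return deferToThread(decrypt)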

View File

@ -64,6 +64,7 @@ from .common import si_b2a, si_to_human_readable
from ..util.hashutil import timing_safe_compare
from ..util.deferredutil import async_to_deferred
from ..util.tor_provider import _Provider as TorProvider
from ..util.cputhreadpool import defer_to_thread
try:
from txtorcon import Tor # type: ignore
@ -473,7 +474,8 @@ class StorageClient(object):
into corresponding HTTP headers.
If ``message_to_serialize`` is set, it will be serialized (by default
with CBOR) and set as the request body.
with CBOR) and set as the request body. It should not be mutated
during execution of this function!
Default timeout is 60 seconds.
"""
@ -539,7 +541,7 @@ class StorageClient(object):
"Can't use both `message_to_serialize` and `data` "
"as keyword arguments at the same time"
)
kwargs["data"] = dumps(message_to_serialize)
kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
response = await self._treq.request(
@ -557,8 +559,12 @@ class StorageClient(object):
if content_type == CBOR_MIME_TYPE:
f = await limited_content(response, self._clock)
data = f.read()
schema.validate_cbor(data)
return loads(data)
def validate_and_decode():
schema.validate_cbor(data)
return loads(data)
return await defer_to_thread(validate_and_decode)
else:
raise ClientException(
-1,
@ -1232,7 +1238,8 @@ class StorageClientMutables:
return cast(
Set[int],
await self._client.decode_cbor(
response, _SCHEMAS["mutable_list_shares"]
response,
_SCHEMAS["mutable_list_shares"],
),
)
else:
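
Serializing and validating CBOR can be slow for large messages, so both directions now run in the pool. The new docstring caveat exists precisely because dumps() now executes concurrently with reactor-thread code: mutating the message before the Deferred fires would be a data race. An illustrative sketch, assuming the cbor2 package.

from cbor2 import dumps
from twisted.internet import reactor
from twisted.internet.threads import deferToThread

def main():
    message = {"share-numbers": [1, 5], "allocated-size": 1024}
    d = deferToThread(dumps, message)
    # Mutating `message` here, before the Deferred fires, would race with
    # the serializer thread; the docstring above forbids exactly that.
    d.addCallback(lambda body: print(len(body), "bytes of CBOR"))
    d.addBoth(lambda _: reactor.stop())

reactor.callWhenRunning(main)
reactor.run()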

View File

@ -638,10 +638,7 @@ async def read_encoded(
# Pycddl will release the GIL when validating larger documents, so
# let's take advantage of multiple CPUs:
if size > 10_000:
await defer_to_thread(reactor, schema.validate_cbor, message)
else:
schema.validate_cbor(message)
await defer_to_thread(schema.validate_cbor, message)
# The CBOR parser will allocate more memory, but at least we can feed
it the file-like object, so that if it's large it won't make two
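
Because pycddl releases the GIL while validating, the worker thread genuinely runs in parallel with the reactor; dropping the 10 kB threshold trades a small thread-hop cost on tiny documents for simpler code. A hedged sketch of validating CBOR against a CDDL schema in a thread; it assumes the pycddl and cbor2 packages, and the schema itself is illustrative.

from cbor2 import dumps
from pycddl import Schema
from twisted.internet import reactor
from twisted.internet.threads import deferToThread

schema = Schema("""
    message = {
        name: tstr,
        size: uint,
    }
""")

def main():
    body = dumps({"name": "example", "size": 123})
    # pycddl releases the GIL during validation, so this worker thread can
    # run in parallel with the reactor on a multi-core machine.
    d = deferToThread(schema.validate_cbor, body)
    d.addCallbacks(lambda _: print("valid"), lambda f: print("invalid:", f.value))
    d.addBoth(lambda _: reactor.stop())

reactor.callWhenRunning(main)
reactor.run()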

View File

@ -0,0 +1,40 @@
import sys
import traceback
import signal
import threading
from twisted.internet import reactor
def print_stacks():
print("Uh oh, something is blocking the event loop!")
current_thread = threading.get_ident()
for thread_id, frame in sys._current_frames().items():
if thread_id == current_thread:
traceback.print_stack(frame, limit=10)
break
def catch_blocking_in_event_loop(test=None):
"""
Print tracebacks if the event loop is blocked for more than a short amount
of time.
"""
signal.signal(signal.SIGALRM, lambda *args: print_stacks())
current_scheduled = [None]
def cancel_and_rerun():
signal.setitimer(signal.ITIMER_REAL, 0)
signal.setitimer(signal.ITIMER_REAL, 0.015)
current_scheduled[0] = reactor.callLater(0.01, cancel_and_rerun)
cancel_and_rerun()
def cleanup():
signal.signal(signal.SIGALRM, signal.SIG_DFL)
signal.setitimer(signal.ITIMER_REAL, 0)
current_scheduled[0].cancel()
if test is not None:
test.addCleanup(cleanup)
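
The watchdog works by racing a POSIX interval timer against the reactor: every 10 ms the reactor re-arms a 15 ms SIGALRM, so the alarm only fires if the event loop stalls long enough to miss a re-arm, at which point the handler prints the event-loop thread's stack. A small demonstration script, assuming the module lands at allmydata.test.blocking as the relative import later in this diff suggests; POSIX-only, since it relies on signal.setitimer.

import time

from twisted.internet import reactor
from allmydata.test.blocking import catch_blocking_in_event_loop

def main():
    catch_blocking_in_event_loop()
    # Sleeping on the reactor thread stalls the event loop for 100ms, well
    # past the 15ms alarm, so print_stacks() should report this frame.
    reactor.callLater(0.5, lambda: time.sleep(0.1))
    reactor.callLater(1.0, reactor.stop)

reactor.callWhenRunning(main)
reactor.run()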

View File

@ -424,7 +424,7 @@ class Check(GridTestMixin, CLITestMixin, unittest.TestCase):
def _stash_uri(n):
self.uriList.append(n.get_uri())
d.addCallback(_stash_uri)
d = c0.create_dirnode()
d.addCallback(lambda _: c0.create_dirnode())
d.addCallback(_stash_uri)
d.addCallback(lambda ign: self.do_cli("check", self.uriList[0], self.uriList[1]))

View File

@ -685,6 +685,10 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
REDUCE_HTTP_CLIENT_TIMEOUT : bool = True
def setUp(self):
if os.getenv("TAHOE_DEBUG_BLOCKING") == "1":
from .blocking import catch_blocking_in_event_loop
catch_blocking_in_event_loop(self)
self._http_client_pools = []
http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool)
self.addCleanup(http_client.StorageClientFactory.stop_test_mode)

View File

@ -252,7 +252,7 @@ def create_no_network_client(basedir):
i2p_provider=None,
tor_provider=None,
introducer_clients=[],
storage_farm_broker=storage_broker,
storage_farm_broker=storage_broker
)
# this is a (pre-existing) reference-cycle and also a bad idea, see:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/2949

View File

@ -43,6 +43,7 @@ from testtools.matchers import Equals
from zope.interface import implementer
from ..util.deferredutil import async_to_deferred
from ..util.cputhreadpool import disable_thread_pool_for_test
from .common import SyncTestCase
from ..storage.http_common import (
get_content_type,
@ -345,6 +346,7 @@ class CustomHTTPServerTests(SyncTestCase):
def setUp(self):
super(CustomHTTPServerTests, self).setUp()
disable_thread_pool_for_test(self)
StorageClientFactory.start_test_mode(
lambda pool: self.addCleanup(pool.closeCachedConnections)
)
@ -701,6 +703,7 @@ class GenericHTTPAPITests(SyncTestCase):
def setUp(self):
super(GenericHTTPAPITests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
def test_missing_authentication(self) -> None:
@ -808,6 +811,7 @@ class ImmutableHTTPAPITests(SyncTestCase):
def setUp(self):
super(ImmutableHTTPAPITests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.imm_client = StorageClientImmutables(self.http.client)
self.general_client = StorageClientGeneral(self.http.client)
@ -1317,6 +1321,7 @@ class MutableHTTPAPIsTests(SyncTestCase):
def setUp(self):
super(MutableHTTPAPIsTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.mut_client = StorageClientMutables(self.http.client)
@ -1734,6 +1739,7 @@ class ImmutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
def setUp(self):
super(ImmutableSharedTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.client = self.clientFactory(self.http.client)
self.general_client = StorageClientGeneral(self.http.client)
@ -1788,6 +1794,7 @@ class MutableSharedTests(SharedImmutableMutableTestsMixin, SyncTestCase):
def setUp(self):
super(MutableSharedTests, self).setUp()
disable_thread_pool_for_test(self)
self.http = self.useFixture(HttpTestFixture())
self.client = self.clientFactory(self.http.client)
self.general_client = StorageClientGeneral(self.http.client)

View File

@ -18,7 +18,6 @@ import json
from threading import current_thread
from twisted.trial import unittest
from twisted.internet import reactor
from foolscap.api import Violation, RemoteException
from allmydata.util import idlib, mathutil
@ -28,7 +27,7 @@ from allmydata.util import pollmixin
from allmydata.util import yamlutil
from allmydata.util import rrefutil
from allmydata.util.fileutil import EncryptedTemporaryFile
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.cputhreadpool import defer_to_thread, disable_thread_pool_for_test
from allmydata.test.common_util import ReallyEqualMixin
from .no_network import fireNow, LocalWrapper
@ -613,20 +612,25 @@ class CPUThreadPool(unittest.TestCase):
return current_thread(), args, kwargs
this_thread = current_thread().ident
result = defer_to_thread(reactor, f, 1, 3, key=4, value=5)
# Callbacks run in the correct thread:
callback_thread_ident = []
def passthrough(result):
callback_thread_ident.append(current_thread().ident)
return result
result.addCallback(passthrough)
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
# The task ran in a different thread:
thread, args, kwargs = await result
self.assertEqual(callback_thread_ident[0], this_thread)
self.assertNotEqual(thread.ident, this_thread)
self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5})
async def test_when_disabled_runs_in_same_thread(self):
"""
If the CPU thread pool is disabled, the given function runs in the
current thread.
"""
disable_thread_pool_for_test(self)
def f(*args, **kwargs):
return current_thread().ident, args, kwargs
this_thread = current_thread().ident
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
self.assertEqual(thread, this_thread)
self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5})

View File

@ -15,16 +15,15 @@ scheduler affinity or cgroups, but that's not the end of the world.
"""
import os
from typing import TypeVar, Callable, cast
from typing import TypeVar, Callable
from functools import partial
import threading
from typing_extensions import ParamSpec
from unittest import TestCase
from twisted.python.threadpool import ThreadPool
from twisted.internet.defer import Deferred
from twisted.internet.threads import deferToThreadPool
from twisted.internet.interfaces import IReactorFromThreads
from twisted.internet import reactor
_CPU_THREAD_POOL = ThreadPool(minthreads=0, maxthreads=os.cpu_count(), name="TahoeCPU")
if hasattr(threading, "_register_atexit"):
@ -46,14 +45,43 @@ _CPU_THREAD_POOL.start()
P = ParamSpec("P")
R = TypeVar("R")
# Is running in a thread pool disabled? Should only be true in synchronous unit
# tests.
_DISABLED = False
async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Run the function in a thread, return the result.
However, if ``disable_thread_pool_for_test()`` was called the function will
be called synchronously inside the current thread.
To reduce the chances of synchronous tests being misleading as a result, this
is an async function, on the presumption that this will encourage immediate ``await``ing.
"""
if _DISABLED:
return f(*args, **kwargs)
def defer_to_thread(
reactor: IReactorFromThreads, f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Deferred[R]:
"""Run the function in a thread, return the result as a ``Deferred``."""
# deferToThreadPool has no type annotations...
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
return cast(Deferred[R], result)
result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
return result
__all__ = ["defer_to_thread"]
def disable_thread_pool_for_test(test: TestCase) -> None:
"""
For the duration of the test, calls to ``defer_to_thread()`` will actually
run synchronously, which is useful for synchronous unit tests.
"""
global _DISABLED
def restore():
global _DISABLED
_DISABLED = False
test.addCleanup(restore)
_DISABLED = True
__all__ = ["defer_to_thread", "disable_thread_pool_for_test"]