Switch defer_to_thread() API to hopefully be harder to screw up.

This commit is contained in:
Itamar Turner-Trauring 2023-10-19 13:49:41 -04:00
parent 303e45b1e5
commit 20cfe70d48
4 changed files with 26 additions and 34 deletions

View File

@ -48,6 +48,7 @@ from allmydata.util.time_format import parse_duration, parse_date
from allmydata.util.i2p_provider import create as create_i2p_provider
from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.deferredutil import async_to_deferred
from allmydata.stats import StatsProvider
from allmydata.history import History
from allmydata.interfaces import (
@ -171,15 +172,17 @@ class KeyGenerator(object):
"""I create RSA keys for mutable files. Each call to generate() returns a
single keypair."""
def generate(self) -> defer.Deferred[tuple[rsa.PublicKey, rsa.PrivateKey]]:
@async_to_deferred
async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]:
"""
I return a Deferred that fires with a (verifyingkey, signingkey)
pair. The returned key will be 2048 bit.
"""
keysize = 2048
return defer_to_thread(
private, public = await defer_to_thread(
rsa.create_signing_keypair, keysize
).addCallback(lambda t: (t[1], t[0]))
)
return public, private
class Terminator(service.Service):

View File

@ -16,6 +16,7 @@ from zope.interface import implementer
from allmydata.util import mathutil
from allmydata.util.assertutil import precondition
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.util.deferredutil import async_to_deferred
from allmydata.interfaces import ICodecEncoder, ICodecDecoder
import zfec
@ -45,7 +46,8 @@ class CRSEncoder(object):
def get_block_size(self):
return self.share_size
def encode(self, inshares, desired_share_ids=None):
@async_to_deferred
async def encode(self, inshares, desired_share_ids=None):
precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares)
if desired_share_ids is None:
@ -53,9 +55,8 @@ class CRSEncoder(object):
for inshare in inshares:
assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
d = defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
d.addCallback(lambda shares: (shares, desired_share_ids))
return d
shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
return (shares, desired_share_ids)
def encode_proposal(self, data, desired_share_ids=None):
raise NotImplementedError()
@ -77,12 +78,13 @@ class CRSDecoder(object):
def get_needed_shares(self):
return self.required_shares
def decode(self, some_shares, their_shareids):
@async_to_deferred
async def decode(self, some_shares, their_shareids):
precondition(len(some_shares) == len(their_shareids),
len(some_shares), len(their_shareids))
precondition(len(some_shares) == self.required_shares,
len(some_shares), self.required_shares)
return defer_to_thread(
return await defer_to_thread(
self.decoder.decode,
some_shares,
[int(s) for s in their_shareids]

View File

@ -612,24 +612,14 @@ class CPUThreadPool(unittest.TestCase):
return current_thread(), args, kwargs
this_thread = current_thread().ident
result = defer_to_thread(f, 1, 3, key=4, value=5)
# Callbacks run in the correct thread:
callback_thread_ident = []
def passthrough(result):
callback_thread_ident.append(current_thread().ident)
return result
result.addCallback(passthrough)
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
# The task ran in a different thread:
thread, args, kwargs = await result
self.assertEqual(callback_thread_ident[0], this_thread)
self.assertNotEqual(thread.ident, this_thread)
self.assertEqual(args, (1, 3))
self.assertEqual(kwargs, {"key": 4, "value": 5})
def test_when_disabled_runs_in_same_thread(self):
async def test_when_disabled_runs_in_same_thread(self):
"""
If the CPU thread pool is disabled, the given function runs in the
current thread.
@ -639,10 +629,7 @@ class CPUThreadPool(unittest.TestCase):
return current_thread().ident, args, kwargs
this_thread = current_thread().ident
result = defer_to_thread(f, 1, 3, key=4, value=5)
l = []
result.addCallback(l.append)
thread, args, kwargs = l[0]
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
self.assertEqual(thread, this_thread)
self.assertEqual(args, (1, 3))

View File

@ -15,14 +15,13 @@ scheduler affinity or cgroups, but that's not the end of the world.
"""
import os
from typing import TypeVar, Callable, cast
from typing import TypeVar, Callable
from functools import partial
import threading
from typing_extensions import ParamSpec
from unittest import TestCase
from twisted.python.threadpool import ThreadPool
from twisted.internet.defer import Deferred, maybeDeferred
from twisted.internet.threads import deferToThreadPool
from twisted.internet import reactor
@ -51,21 +50,22 @@ R = TypeVar("R")
_DISABLED = False
def defer_to_thread(
f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Deferred[R]:
async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Run the function in a thread, return the result as a ``Deferred``.
Run the function in a thread, return the result.
However, if ``disable_thread_pool_for_test()`` was called the function will
be called synchronously inside the current thread.
To reduce chances of synchronous tests being misleading as a result, this
is an async function on presumption that will encourage immediate ``await``ing.
"""
if _DISABLED:
return maybeDeferred(f, *args, **kwargs)
return f(*args, **kwargs)
# deferToThreadPool has no type annotations...
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
return cast(Deferred[R], result)
result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
return result
def disable_thread_pool_for_test(test: TestCase) -> None: