mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2025-02-25 19:21:36 +00:00
Switch defer_to_thread() API to hopefully be harder to screw up.
This commit is contained in:
parent
303e45b1e5
commit
20cfe70d48
@ -48,6 +48,7 @@ from allmydata.util.time_format import parse_duration, parse_date
|
||||
from allmydata.util.i2p_provider import create as create_i2p_provider
|
||||
from allmydata.util.tor_provider import create as create_tor_provider, _Provider as TorProvider
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata.util.deferredutil import async_to_deferred
|
||||
from allmydata.stats import StatsProvider
|
||||
from allmydata.history import History
|
||||
from allmydata.interfaces import (
|
||||
@ -171,15 +172,17 @@ class KeyGenerator(object):
|
||||
"""I create RSA keys for mutable files. Each call to generate() returns a
|
||||
single keypair."""
|
||||
|
||||
def generate(self) -> defer.Deferred[tuple[rsa.PublicKey, rsa.PrivateKey]]:
|
||||
@async_to_deferred
|
||||
async def generate(self) -> tuple[rsa.PublicKey, rsa.PrivateKey]:
|
||||
"""
|
||||
I return a Deferred that fires with a (verifyingkey, signingkey)
|
||||
pair. The returned key will be 2048 bit.
|
||||
"""
|
||||
keysize = 2048
|
||||
return defer_to_thread(
|
||||
private, public = await defer_to_thread(
|
||||
rsa.create_signing_keypair, keysize
|
||||
).addCallback(lambda t: (t[1], t[0]))
|
||||
)
|
||||
return public, private
|
||||
|
||||
|
||||
class Terminator(service.Service):
|
||||
|
@ -16,6 +16,7 @@ from zope.interface import implementer
|
||||
from allmydata.util import mathutil
|
||||
from allmydata.util.assertutil import precondition
|
||||
from allmydata.util.cputhreadpool import defer_to_thread
|
||||
from allmydata.util.deferredutil import async_to_deferred
|
||||
from allmydata.interfaces import ICodecEncoder, ICodecDecoder
|
||||
import zfec
|
||||
|
||||
@ -45,7 +46,8 @@ class CRSEncoder(object):
|
||||
def get_block_size(self):
|
||||
return self.share_size
|
||||
|
||||
def encode(self, inshares, desired_share_ids=None):
|
||||
@async_to_deferred
|
||||
async def encode(self, inshares, desired_share_ids=None):
|
||||
precondition(desired_share_ids is None or len(desired_share_ids) <= self.max_shares, desired_share_ids, self.max_shares)
|
||||
|
||||
if desired_share_ids is None:
|
||||
@ -53,9 +55,8 @@ class CRSEncoder(object):
|
||||
|
||||
for inshare in inshares:
|
||||
assert len(inshare) == self.share_size, (len(inshare), self.share_size, self.data_size, self.required_shares)
|
||||
d = defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
|
||||
d.addCallback(lambda shares: (shares, desired_share_ids))
|
||||
return d
|
||||
shares = await defer_to_thread(self.encoder.encode, inshares, desired_share_ids)
|
||||
return (shares, desired_share_ids)
|
||||
|
||||
def encode_proposal(self, data, desired_share_ids=None):
|
||||
raise NotImplementedError()
|
||||
@ -77,12 +78,13 @@ class CRSDecoder(object):
|
||||
def get_needed_shares(self):
|
||||
return self.required_shares
|
||||
|
||||
def decode(self, some_shares, their_shareids):
|
||||
@async_to_deferred
|
||||
async def decode(self, some_shares, their_shareids):
|
||||
precondition(len(some_shares) == len(their_shareids),
|
||||
len(some_shares), len(their_shareids))
|
||||
precondition(len(some_shares) == self.required_shares,
|
||||
len(some_shares), self.required_shares)
|
||||
return defer_to_thread(
|
||||
return await defer_to_thread(
|
||||
self.decoder.decode,
|
||||
some_shares,
|
||||
[int(s) for s in their_shareids]
|
||||
|
@ -612,24 +612,14 @@ class CPUThreadPool(unittest.TestCase):
|
||||
return current_thread(), args, kwargs
|
||||
|
||||
this_thread = current_thread().ident
|
||||
result = defer_to_thread(f, 1, 3, key=4, value=5)
|
||||
|
||||
# Callbacks run in the correct thread:
|
||||
callback_thread_ident = []
|
||||
def passthrough(result):
|
||||
callback_thread_ident.append(current_thread().ident)
|
||||
return result
|
||||
|
||||
result.addCallback(passthrough)
|
||||
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
|
||||
|
||||
# The task ran in a different thread:
|
||||
thread, args, kwargs = await result
|
||||
self.assertEqual(callback_thread_ident[0], this_thread)
|
||||
self.assertNotEqual(thread.ident, this_thread)
|
||||
self.assertEqual(args, (1, 3))
|
||||
self.assertEqual(kwargs, {"key": 4, "value": 5})
|
||||
|
||||
def test_when_disabled_runs_in_same_thread(self):
|
||||
async def test_when_disabled_runs_in_same_thread(self):
|
||||
"""
|
||||
If the CPU thread pool is disabled, the given function runs in the
|
||||
current thread.
|
||||
@ -639,10 +629,7 @@ class CPUThreadPool(unittest.TestCase):
|
||||
return current_thread().ident, args, kwargs
|
||||
|
||||
this_thread = current_thread().ident
|
||||
result = defer_to_thread(f, 1, 3, key=4, value=5)
|
||||
l = []
|
||||
result.addCallback(l.append)
|
||||
thread, args, kwargs = l[0]
|
||||
thread, args, kwargs = await defer_to_thread(f, 1, 3, key=4, value=5)
|
||||
|
||||
self.assertEqual(thread, this_thread)
|
||||
self.assertEqual(args, (1, 3))
|
||||
|
@ -15,14 +15,13 @@ scheduler affinity or cgroups, but that's not the end of the world.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import TypeVar, Callable, cast
|
||||
from typing import TypeVar, Callable
|
||||
from functools import partial
|
||||
import threading
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest import TestCase
|
||||
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
from twisted.internet.defer import Deferred, maybeDeferred
|
||||
from twisted.internet.threads import deferToThreadPool
|
||||
from twisted.internet import reactor
|
||||
|
||||
@ -51,21 +50,22 @@ R = TypeVar("R")
|
||||
_DISABLED = False
|
||||
|
||||
|
||||
def defer_to_thread(
|
||||
f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> Deferred[R]:
|
||||
async def defer_to_thread(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Run the function in a thread, return the result as a ``Deferred``.
|
||||
Run the function in a thread, return the result.
|
||||
|
||||
However, if ``disable_thread_pool_for_test()`` was called the function will
|
||||
be called synchronously inside the current thread.
|
||||
|
||||
To reduce chances of synchronous tests being misleading as a result, this
|
||||
is an async function on presumption that will encourage immediate ``await``ing.
|
||||
"""
|
||||
if _DISABLED:
|
||||
return maybeDeferred(f, *args, **kwargs)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# deferToThreadPool has no type annotations...
|
||||
result = deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
|
||||
return cast(Deferred[R], result)
|
||||
result = await deferToThreadPool(reactor, _CPU_THREAD_POOL, f, *args, **kwargs)
|
||||
return result
|
||||
|
||||
|
||||
def disable_thread_pool_for_test(test: TestCase) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user