From 2dbf90cafeb0e58f6724edbced183a776e2e7195 Mon Sep 17 00:00:00 2001 From: Joseph Zuniga-Daly <59851625+josephzunigadaly@users.noreply.github.com> Date: Wed, 25 Mar 2020 11:53:06 +0000 Subject: [PATCH] ENT-4857: Fix race condition in trackTransaction (#6096) - Fix issue - Emit warning if we are inside a DB transaction - Include a path that does not emit warning - Add unit tests --- .../node/services/api/ServiceHubInternal.kt | 6 + .../persistence/DBTransactionStorage.kt | 14 +++ .../statemachine/ActionExecutorImpl.kt | 2 +- .../node/messaging/TwoPartyTradeFlowTests.kt | 6 + .../persistence/DBTransactionStorageTests.kt | 113 +++++++++++++++++- .../node/internal/MockTransactionStorage.kt | 4 + 6 files changed, 143 insertions(+), 2 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 984d0b216f..0009424917 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -260,6 +260,12 @@ interface WritableTransactionStorage : TransactionStorage { * ID exists. */ fun getTransactionInternal(id: SecureHash): Pair? + + /** + * Returns a future that completes with the transaction corresponding to [id] once it has been committed. Do not warn when run inside + * a DB transaction. + */ + fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture } /** diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 14f7492139..3c8b8fac56 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -133,6 +133,8 @@ class DBTransactionStorage(private val database: CordaPersistence, cacheFactory: val actTx = tx.peekableValue ?: return 0 return actTx.sigs.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.txBits.size } + + private val log = contextLogger() } private val txStorage = ThreadBox(createTransactionsMap(cacheFactory)) @@ -211,12 +213,24 @@ class DBTransactionStorage(private val database: CordaPersistence, cacheFactory: } override fun trackTransaction(id: SecureHash): CordaFuture { + + if (contextTransactionOrNull != null) { + log.warn("trackTransaction is called with an already existing, open DB transaction. As a result, there might be transactions missing from the returned data feed, because of race conditions.") + } + + return trackTransactionWithNoWarning(id) + } + + override fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture { + val updateFuture = updates.filter { it.id == id }.toFuture() return database.transaction { txStorage.locked { val existingTransaction = getTransaction(id) if (existingTransaction == null) { updates.filter { it.id == id }.toFuture() + updateFuture } else { + updateFuture.cancel(false) doneFuture(existingTransaction) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index 237b1097f7..7db0f4c8dc 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -83,7 +83,7 @@ class ActionExecutorImpl( @Suspendable private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) { - services.validatedTransactions.trackTransaction(action.hash).thenMatch( + services.validatedTransactions.trackTransactionWithNoWarning(action.hash).thenMatch( success = { transaction -> fiber.scheduleEvent(Event.TransactionCommitted(transaction)) }, diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index a11e44c844..cdeb571f01 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -733,6 +733,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { } } + override fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture { + return database.transaction { + delegate.trackTransactionWithNoWarning(id) + } + } + override fun track(): DataFeed, SignedTransaction> { return database.transaction { delegate.track() diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt index fff7305075..8873f4d25a 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt @@ -1,6 +1,7 @@ package net.corda.node.services.persistence import junit.framework.TestCase.assertTrue +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.StateRef import net.corda.core.crypto.Crypto import net.corda.core.crypto.SecureHash @@ -17,12 +18,21 @@ import net.corda.testing.internal.TestingNamedCacheFactory import net.corda.testing.internal.configureDatabase import net.corda.testing.internal.createWireTransaction import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties +import org.apache.logging.log4j.LogManager +import org.apache.logging.log4j.core.Appender +import org.apache.logging.log4j.core.LoggerContext +import org.apache.logging.log4j.core.appender.WriterAppender import org.assertj.core.api.Assertions.assertThat import org.junit.After +import org.junit.Assert import org.junit.Before import org.junit.Rule import org.junit.Test +import rx.plugins.RxJavaHooks +import java.io.StringWriter +import java.util.concurrent.Semaphore import java.util.concurrent.TimeUnit +import kotlin.concurrent.thread import kotlin.test.assertEquals class DBTransactionStorageTests { @@ -33,7 +43,7 @@ class DBTransactionStorageTests { @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testSerialization = SerializationEnvironmentRule(inheritable = true) private lateinit var database: CordaPersistence private lateinit var transactionStorage: DBTransactionStorage @@ -198,6 +208,107 @@ class DBTransactionStorageTests { assertTransactionIsRetrievable(secondTransaction) } + @Suppress("UnstableApiUsage") + @Test(timeout=300_000) + fun `race condition - failure path`() { + + // Insert a sleep into trackTransaction + RxJavaHooks.setOnObservableCreate { + Thread.sleep(1_000) + it + } + try { + `race condition - ok path`() + } finally { + // Remove sleep so it does not affect other tests + RxJavaHooks.setOnObservableCreate { it } + } + } + + @Test(timeout=300_000) + fun `race condition - ok path`() { + + // Arrange + + val signedTransaction = newTransaction() + + val threadCount = 2 + val finishedThreadsSemaphore = Semaphore(threadCount) + finishedThreadsSemaphore.acquire(threadCount) + + // Act + + thread(name = "addTransaction") { + transactionStorage.addTransaction(signedTransaction) + finishedThreadsSemaphore.release() + } + + var result: CordaFuture? = null + thread(name = "trackTransaction") { + result = transactionStorage.trackTransaction(signedTransaction.id) + finishedThreadsSemaphore.release() + } + + if (!finishedThreadsSemaphore.tryAcquire(threadCount, 1, TimeUnit.MINUTES)) { + Assert.fail("Threads did not finish") + } + + // Assert + + assertThat(result).isNotNull() + assertThat(result?.get(20, TimeUnit.SECONDS)?.id).isEqualTo(signedTransaction.id) + } + + @Test(timeout=300_000) + fun `race condition - transaction warning`() { + + // Arrange + val signedTransaction = newTransaction() + + // Act + val logMessages = collectLogsFrom { + database.transaction { + val result = transactionStorage.trackTransaction(signedTransaction.id) + result.cancel(false) + } + } + + // Assert + assertThat(logMessages).contains("trackTransaction is called with an already existing, open DB transaction. As a result, there might be transactions missing from the returned data feed, because of race conditions.") + } + + private fun collectLogsFrom(statement: () -> Unit): String { + // Create test appender + val stringWriter = StringWriter() + val appenderName = this::collectLogsFrom.name + val appender: Appender = WriterAppender.createAppender( + null, + null, + stringWriter, + appenderName, + false, + true + ) + appender.start() + + // Add test appender + val context = LogManager.getContext(false) as LoggerContext + val configuration = context.configuration + configuration.addAppender(appender) + configuration.loggers.values.forEach { it.addAppender(appender, null, null) } + + try { + statement() + } finally { + // Remove test appender + configuration.loggers.values.forEach { it.removeAppender(appenderName) } + configuration.appenders.remove(appenderName) + appender.stop() + } + + return stringWriter.toString() + } + private fun newTransactionStorage(cacheSizeBytesOverride: Long? = null) { transactionStorage = DBTransactionStorage(database, TestingNamedCacheFactory(cacheSizeBytesOverride ?: 1024)) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt index d9d090b62f..c1cebf95e1 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt @@ -21,6 +21,10 @@ open class MockTransactionStorage : WritableTransactionStorage, SingletonSeriali return getTransaction(id)?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture() } + override fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture { + return trackTransaction(id) + } + override fun track(): DataFeed, SignedTransaction> { return DataFeed(txns.values.mapNotNull { if (it.isVerified) it.stx else null }, _updatesPublisher) }