ENT-4857: Fix race condition in trackTransaction (#6096)

- Fix issue
- Emit warning if we are inside a DB transaction
- Include a path that does not emit warning
- Add unit tests
This commit is contained in:
Joseph Zuniga-Daly 2020-03-25 11:53:06 +00:00 committed by GitHub
parent 8f54ef740f
commit 2dbf90cafe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 143 additions and 2 deletions

View File

@ -260,6 +260,12 @@ interface WritableTransactionStorage : TransactionStorage {
* ID exists. * ID exists.
*/ */
fun getTransactionInternal(id: SecureHash): Pair<SignedTransaction, Boolean>? fun getTransactionInternal(id: SecureHash): Pair<SignedTransaction, Boolean>?
/**
* Returns a future that completes with the transaction corresponding to [id] once it has been committed. Do not warn when run inside
* a DB transaction.
*/
fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture<SignedTransaction>
} }
/** /**

View File

@ -133,6 +133,8 @@ class DBTransactionStorage(private val database: CordaPersistence, cacheFactory:
val actTx = tx.peekableValue ?: return 0 val actTx = tx.peekableValue ?: return 0
return actTx.sigs.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.txBits.size return actTx.sigs.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.txBits.size
} }
private val log = contextLogger()
} }
private val txStorage = ThreadBox(createTransactionsMap(cacheFactory)) private val txStorage = ThreadBox(createTransactionsMap(cacheFactory))
@ -211,12 +213,24 @@ class DBTransactionStorage(private val database: CordaPersistence, cacheFactory:
} }
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> { override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
if (contextTransactionOrNull != null) {
log.warn("trackTransaction is called with an already existing, open DB transaction. As a result, there might be transactions missing from the returned data feed, because of race conditions.")
}
return trackTransactionWithNoWarning(id)
}
override fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture<SignedTransaction> {
val updateFuture = updates.filter { it.id == id }.toFuture()
return database.transaction { return database.transaction {
txStorage.locked { txStorage.locked {
val existingTransaction = getTransaction(id) val existingTransaction = getTransaction(id)
if (existingTransaction == null) { if (existingTransaction == null) {
updates.filter { it.id == id }.toFuture() updates.filter { it.id == id }.toFuture()
updateFuture
} else { } else {
updateFuture.cancel(false)
doneFuture(existingTransaction) doneFuture(existingTransaction)
} }
} }

View File

@ -83,7 +83,7 @@ class ActionExecutorImpl(
@Suspendable @Suspendable
private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) { private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) {
services.validatedTransactions.trackTransaction(action.hash).thenMatch( services.validatedTransactions.trackTransactionWithNoWarning(action.hash).thenMatch(
success = { transaction -> success = { transaction ->
fiber.scheduleEvent(Event.TransactionCommitted(transaction)) fiber.scheduleEvent(Event.TransactionCommitted(transaction))
}, },

View File

@ -733,6 +733,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
} }
} }
override fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture<SignedTransaction> {
return database.transaction {
delegate.trackTransactionWithNoWarning(id)
}
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> { override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return database.transaction { return database.transaction {
delegate.track() delegate.track()

View File

@ -1,6 +1,7 @@
package net.corda.node.services.persistence package net.corda.node.services.persistence
import junit.framework.TestCase.assertTrue import junit.framework.TestCase.assertTrue
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
@ -17,12 +18,21 @@ import net.corda.testing.internal.TestingNamedCacheFactory
import net.corda.testing.internal.configureDatabase import net.corda.testing.internal.configureDatabase
import net.corda.testing.internal.createWireTransaction import net.corda.testing.internal.createWireTransaction
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.core.Appender
import org.apache.logging.log4j.core.LoggerContext
import org.apache.logging.log4j.core.appender.WriterAppender
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Assert
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import rx.plugins.RxJavaHooks
import java.io.StringWriter
import java.util.concurrent.Semaphore
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.concurrent.thread
import kotlin.test.assertEquals import kotlin.test.assertEquals
class DBTransactionStorageTests { class DBTransactionStorageTests {
@ -33,7 +43,7 @@ class DBTransactionStorageTests {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testSerialization = SerializationEnvironmentRule(inheritable = true)
private lateinit var database: CordaPersistence private lateinit var database: CordaPersistence
private lateinit var transactionStorage: DBTransactionStorage private lateinit var transactionStorage: DBTransactionStorage
@ -198,6 +208,107 @@ class DBTransactionStorageTests {
assertTransactionIsRetrievable(secondTransaction) assertTransactionIsRetrievable(secondTransaction)
} }
@Suppress("UnstableApiUsage")
@Test(timeout=300_000)
fun `race condition - failure path`() {
// Insert a sleep into trackTransaction
RxJavaHooks.setOnObservableCreate {
Thread.sleep(1_000)
it
}
try {
`race condition - ok path`()
} finally {
// Remove sleep so it does not affect other tests
RxJavaHooks.setOnObservableCreate { it }
}
}
@Test(timeout=300_000)
fun `race condition - ok path`() {
// Arrange
val signedTransaction = newTransaction()
val threadCount = 2
val finishedThreadsSemaphore = Semaphore(threadCount)
finishedThreadsSemaphore.acquire(threadCount)
// Act
thread(name = "addTransaction") {
transactionStorage.addTransaction(signedTransaction)
finishedThreadsSemaphore.release()
}
var result: CordaFuture<SignedTransaction>? = null
thread(name = "trackTransaction") {
result = transactionStorage.trackTransaction(signedTransaction.id)
finishedThreadsSemaphore.release()
}
if (!finishedThreadsSemaphore.tryAcquire(threadCount, 1, TimeUnit.MINUTES)) {
Assert.fail("Threads did not finish")
}
// Assert
assertThat(result).isNotNull()
assertThat(result?.get(20, TimeUnit.SECONDS)?.id).isEqualTo(signedTransaction.id)
}
@Test(timeout=300_000)
fun `race condition - transaction warning`() {
// Arrange
val signedTransaction = newTransaction()
// Act
val logMessages = collectLogsFrom {
database.transaction {
val result = transactionStorage.trackTransaction(signedTransaction.id)
result.cancel(false)
}
}
// Assert
assertThat(logMessages).contains("trackTransaction is called with an already existing, open DB transaction. As a result, there might be transactions missing from the returned data feed, because of race conditions.")
}
private fun collectLogsFrom(statement: () -> Unit): String {
// Create test appender
val stringWriter = StringWriter()
val appenderName = this::collectLogsFrom.name
val appender: Appender = WriterAppender.createAppender(
null,
null,
stringWriter,
appenderName,
false,
true
)
appender.start()
// Add test appender
val context = LogManager.getContext(false) as LoggerContext
val configuration = context.configuration
configuration.addAppender(appender)
configuration.loggers.values.forEach { it.addAppender(appender, null, null) }
try {
statement()
} finally {
// Remove test appender
configuration.loggers.values.forEach { it.removeAppender(appenderName) }
configuration.appenders.remove(appenderName)
appender.stop()
}
return stringWriter.toString()
}
private fun newTransactionStorage(cacheSizeBytesOverride: Long? = null) { private fun newTransactionStorage(cacheSizeBytesOverride: Long? = null) {
transactionStorage = DBTransactionStorage(database, TestingNamedCacheFactory(cacheSizeBytesOverride transactionStorage = DBTransactionStorage(database, TestingNamedCacheFactory(cacheSizeBytesOverride
?: 1024)) ?: 1024))

View File

@ -21,6 +21,10 @@ open class MockTransactionStorage : WritableTransactionStorage, SingletonSeriali
return getTransaction(id)?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture() return getTransaction(id)?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture()
} }
override fun trackTransactionWithNoWarning(id: SecureHash): CordaFuture<SignedTransaction> {
return trackTransaction(id)
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> { override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return DataFeed(txns.values.mapNotNull { if (it.isVerified) it.stx else null }, _updatesPublisher) return DataFeed(txns.values.mapNotNull { if (it.isVerified) it.stx else null }, _updatesPublisher)
} }