Introduce database transactions around message handling, web API and in protocols.

This commit is contained in:
rick.parker 2016-09-12 14:08:18 +01:00
parent dab883dcba
commit 27cb1c3597
10 changed files with 248 additions and 79 deletions

View File

@ -45,10 +45,7 @@ import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.services.transactions.ValidatingNotaryService import com.r3corda.node.services.transactions.ValidatingNotaryService
import com.r3corda.node.services.wallet.CashBalanceAsMetricsObserver import com.r3corda.node.services.wallet.CashBalanceAsMetricsObserver
import com.r3corda.node.services.wallet.NodeWalletService import com.r3corda.node.services.wallet.NodeWalletService
import com.r3corda.node.utilities.ANSIProgressObserver import com.r3corda.node.utilities.*
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase
import org.slf4j.Logger import org.slf4j.Logger
import java.nio.file.FileAlreadyExistsException import java.nio.file.FileAlreadyExistsException
import java.nio.file.Files import java.nio.file.Files
@ -125,7 +122,7 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
lateinit var checkpointStorage: CheckpointStorage lateinit var checkpointStorage: CheckpointStorage
lateinit var smm: StateMachineManager lateinit var smm: StateMachineManager
lateinit var wallet: WalletService lateinit var wallet: WalletService
lateinit var keyManagement: E2ETestKeyManagementService lateinit var keyManagement: KeyManagementService
var inNodeNetworkMapService: NetworkMapService? = null var inNodeNetworkMapService: NetworkMapService? = null
var inNodeWalletMonitorService: WalletMonitorService? = null var inNodeWalletMonitorService: WalletMonitorService? = null
var inNodeNotaryService: NotaryService? = null var inNodeNotaryService: NotaryService? = null
@ -163,60 +160,62 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
require(!started) { "Node has already been started" } require(!started) { "Node has already been started" }
log.info("Node starting up ...") log.info("Node starting up ...")
initialiseDatabasePersistence() // Do all of this in a database transaction so anything that might need a connection has one.
val storageServices = initialiseStorageService(dir) initialiseDatabasePersistence() {
storage = storageServices.first val storageServices = initialiseStorageService(dir)
checkpointStorage = storageServices.second storage = storageServices.first
netMapCache = InMemoryNetworkMapCache() checkpointStorage = storageServices.second
net = makeMessagingService() netMapCache = InMemoryNetworkMapCache()
wallet = makeWalletService() net = makeMessagingService()
wallet = makeWalletService()
identity = makeIdentityService() identity = makeIdentityService()
// Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because
// the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with
// the identity key. But the infrastructure to make that easy isn't here yet. // the identity key. But the infrastructure to make that easy isn't here yet.
keyManagement = E2ETestKeyManagementService(setOf(storage.myLegalIdentityKey)) keyManagement = makeKeyManagementService()
api = APIServerImpl(this) api = APIServerImpl(this@AbstractNode)
scheduler = NodeSchedulerService(services) scheduler = NodeSchedulerService(services)
protocolLogicFactory = initialiseProtocolLogicFactory() protocolLogicFactory = initialiseProtocolLogicFactory()
val tokenizableServices = mutableListOf(storage, net, wallet, keyManagement, identity, platformClock, scheduler) val tokenizableServices = mutableListOf(storage, net, wallet, keyManagement, identity, platformClock, scheduler)
customServices.clear() customServices.clear()
customServices.addAll(buildPluginServices(tokenizableServices)) customServices.addAll(buildPluginServices(tokenizableServices))
// TODO: uniquenessProvider creation should be inside makeNotaryService(), but notary service initialisation // TODO: uniquenessProvider creation should be inside makeNotaryService(), but notary service initialisation
// depends on smm, while smm depends on tokenizableServices, which uniquenessProvider is part of // depends on smm, while smm depends on tokenizableServices, which uniquenessProvider is part of
advertisedServices.singleOrNull { it.isSubTypeOf(NotaryService.Type) }?.let { advertisedServices.singleOrNull { it.isSubTypeOf(NotaryService.Type) }?.let {
uniquenessProvider = makeUniquenessProvider() uniquenessProvider = makeUniquenessProvider()
tokenizableServices.add(uniquenessProvider!!) tokenizableServices.add(uniquenessProvider!!)
}
smm = StateMachineManager(services,
listOf(tokenizableServices),
checkpointStorage,
serverThread)
if (serverThread is ExecutorService) {
runOnStop += Runnable {
// We wait here, even though any in-flight messages should have been drained away because the
// server thread can potentially have other non-messaging tasks scheduled onto it. The timeout value is
// arbitrary and might be inappropriate.
MoreExecutors.shutdownAndAwaitTermination(serverThread as ExecutorService, 50, TimeUnit.SECONDS)
} }
smm = StateMachineManager(services,
listOf(tokenizableServices),
checkpointStorage,
serverThread)
if (serverThread is ExecutorService) {
runOnStop += Runnable {
// We wait here, even though any in-flight messages should have been drained away because the
// server thread can potentially have other non-messaging tasks scheduled onto it. The timeout value is
// arbitrary and might be inappropriate.
MoreExecutors.shutdownAndAwaitTermination(serverThread as ExecutorService, 50, TimeUnit.SECONDS)
}
}
inNodeWalletMonitorService = makeWalletMonitorService() // Note this HAS to be after smm is set
buildAdvertisedServices()
// TODO: this model might change but for now it provides some de-coupling
// Add SMM observers
ANSIProgressObserver(smm)
// Add wallet observers
CashBalanceAsMetricsObserver(services)
ScheduledActivityObserver(services)
} }
inNodeWalletMonitorService = makeWalletMonitorService() // Note this HAS to be after smm is set
buildAdvertisedServices()
// TODO: this model might change but for now it provides some de-coupling
// Add SMM observers
ANSIProgressObserver(smm)
// Add wallet observers
CashBalanceAsMetricsObserver(services)
ScheduledActivityObserver(services)
startMessagingService() startMessagingService()
runOnStop += Runnable { net.stop() } runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) _networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
@ -226,13 +225,21 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
return this return this
} }
private fun initialiseDatabasePersistence() { // Specific class so that MockNode can catch it.
class DatabaseConfigurationException(msg: String) : Exception(msg)
protected open fun initialiseDatabasePersistence(insideTransaction: () -> Unit) {
val props = configuration.dataSourceProperties val props = configuration.dataSourceProperties
if (props.isNotEmpty()) { if (props.isNotEmpty()) {
val (toClose, database) = configureDatabase(props) val (toClose, database) = configureDatabase(props)
// Now log the vendor string as this will also cause a connection to be tested eagerly. // Now log the vendor string as this will also cause a connection to be tested eagerly.
log.info("Connected to ${database.vendor} database.") log.info("Connected to ${database.vendor} database.")
runOnStop += Runnable { toClose.close() } runOnStop += Runnable { toClose.close() }
databaseTransaction {
insideTransaction()
}
} else {
throw DatabaseConfigurationException("There must be a database configured.")
} }
} }
@ -326,6 +333,8 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
return future return future
} }
protected open fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(setOf(storage.myLegalIdentityKey))
open protected fun makeNetworkMapService() { open protected fun makeNetworkMapService() {
val expires = platformClock.instant() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD val expires = platformClock.instant() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD
val reg = NodeRegistration(info, Long.MAX_VALUE, AddOrRemove.ADD, expires) val reg = NodeRegistration(info, Long.MAX_VALUE, AddOrRemove.ADD, expires)

View File

@ -18,9 +18,11 @@ import com.r3corda.node.servlets.Config
import com.r3corda.node.servlets.DataUploadServlet import com.r3corda.node.servlets.DataUploadServlet
import com.r3corda.node.servlets.ResponseFilter import com.r3corda.node.servlets.ResponseFilter
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.databaseTransaction
import org.eclipse.jetty.server.* import org.eclipse.jetty.server.*
import org.eclipse.jetty.server.handler.HandlerCollection import org.eclipse.jetty.server.handler.HandlerCollection
import org.eclipse.jetty.servlet.DefaultServlet import org.eclipse.jetty.servlet.DefaultServlet
import org.eclipse.jetty.servlet.FilterHolder
import org.eclipse.jetty.servlet.ServletContextHandler import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.ServletHolder import org.eclipse.jetty.servlet.ServletHolder
import org.eclipse.jetty.util.ssl.SslContextFactory import org.eclipse.jetty.util.ssl.SslContextFactory
@ -33,7 +35,10 @@ import java.lang.management.ManagementFactory
import java.nio.channels.FileLock import java.nio.channels.FileLock
import java.nio.file.Path import java.nio.file.Path
import java.time.Clock import java.time.Clock
import java.util.*
import javax.management.ObjectName import javax.management.ObjectName
import javax.servlet.*
import javax.servlet.http.HttpServletResponse
import kotlin.concurrent.thread import kotlin.concurrent.thread
class ConfigurationException(message: String) : Exception(message) class ConfigurationException(message: String) : Exception(message)
@ -231,6 +236,10 @@ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort,
val jerseyServlet = ServletHolder(container) val jerseyServlet = ServletHolder(container)
addServlet(jerseyServlet, "/api/*") addServlet(jerseyServlet, "/api/*")
jerseyServlet.initOrder = 0 // Initialise at server start jerseyServlet.initOrder = 0 // Initialise at server start
// Wrap all API calls in a database transaction.
val filterHolder = FilterHolder(DatabaseTransactionFilter())
addFilter(filterHolder, "/api/*", EnumSet.of(DispatcherType.REQUEST))
} }
} }
@ -324,4 +333,19 @@ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort,
f.setLength(0) f.setLength(0)
f.write(ourProcessID.toByteArray()) f.write(ourProcessID.toByteArray())
} }
// Servlet filter to wrap API requests with a database transaction.
private class DatabaseTransactionFilter : Filter {
override fun init(filterConfig: FilterConfig?) {
}
override fun destroy() {
}
override fun doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain) {
databaseTransaction {
chain.doFilter(request, response)
}
}
}
} }

View File

@ -8,6 +8,7 @@ import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.api.MessagingServiceInternal import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.databaseTransaction
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.*
@ -217,7 +218,6 @@ class ArtemisMessagingClient(directory: Path,
state.locked { state.locked {
undeliveredMessages += msg undeliveredMessages += msg
} }
return false return false
} }
@ -232,13 +232,20 @@ class ArtemisMessagingClient(directory: Path,
// Note that handlers may re-enter this class. We aren't holding any locks and methods like // Note that handlers may re-enter this class. We aren't holding any locks and methods like
// start/run/stop have re-entrancy assertions at the top, so it is OK. // start/run/stop have re-entrancy assertions at the top, so it is OK.
executor.fetchFrom { executor.fetchFrom {
handler.callback(msg, handler) // TODO: we should be able to clean this up if we separate client and server code, but for now
// interpret persistent as "server" and non-persistent as "client".
if (persistentInbox) {
databaseTransaction {
handler.callback(msg, handler)
}
} else {
handler.callback(msg, handler)
}
} }
} catch(e: Exception) { } catch(e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e) log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e)
} }
} }
return true return true
} }

View File

@ -3,15 +3,20 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberScheduler import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.sql.Connection
import java.sql.SQLException
import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutionException
/** /**
@ -63,20 +68,43 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable @Suppress("UNCHECKED_CAST")
override fun run(): R { override fun run(): R {
createTransaction()
val result = try { val result = try {
logic.call() logic.call()
} catch (t: Throwable) { } catch (t: Throwable) {
commitTransaction()
actionOnEnd() actionOnEnd()
_resultFuture?.setException(t) _resultFuture?.setException(t)
throw ExecutionException(t) throw ExecutionException(t)
} }
// This is to prevent actionOnEnd being called twice if it throws an exception // This is to prevent actionOnEnd being called twice if it throws an exception
commitTransaction()
actionOnEnd() actionOnEnd()
_resultFuture?.set(result) _resultFuture?.set(result)
return result return result
} }
private fun createTransaction() {
// Make sure we have a database transaction
TransactionManager.currentOrNew(Connection.TRANSACTION_REPEATABLE_READ)
logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." }
}
private fun commitTransaction() {
val transaction = TransactionManager.current()
try {
logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." }
transaction.commit()
} catch (e: SQLException) {
// TODO: we will get here if the database is not available. Think about how to shutdown and restart cleanly.
logger.error("Transaction commit failed: ${e.message}", e)
System.exit(1)
} finally {
transaction.close()
}
}
@Suspendable @Suspendable
private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> { private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> {
suspend(receiveRequest) suspend(receiveRequest)
@ -108,6 +136,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
@Suspendable @Suspendable
private fun suspend(protocolIORequest: ProtocolIORequest) { private fun suspend(protocolIORequest: ProtocolIORequest) {
commitTransaction()
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
try { try {
suspendAction(protocolIORequest) suspendAction(protocolIORequest)
@ -118,6 +147,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
_resultFuture?.setException(t) _resultFuture?.setException(t)
} }
} }
createTransaction()
} }
} }

View File

@ -5,9 +5,11 @@ import com.zaxxer.hikari.HikariDataSource
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.Transaction
import java.io.Closeable import java.io.Closeable
import java.sql.Connection
import java.util.* import java.util.*
fun <T> databaseTransaction(statement: Transaction.() -> T): T = org.jetbrains.exposed.sql.transactions.transaction(statement) // TODO: Handle commit failure due to database unavailable. Better to shutdown and await database reconnect/recovery.
fun <T> databaseTransaction(statement: Transaction.() -> T): T = org.jetbrains.exposed.sql.transactions.transaction(Connection.TRANSACTION_REPEATABLE_READ, 1, statement)
fun configureDatabase(props: Properties): Pair<Closeable, Database> { fun configureDatabase(props: Properties): Pair<Closeable, Database> {
val config = HikariConfig(props) val config = HikariConfig(props)

View File

@ -5,7 +5,7 @@ keyStorePassword = "cordacadevpass"
trustStorePassword = "trustpass" trustStorePassword = "trustpass"
dataSourceProperties = { dataSourceProperties = {
dataSourceClassName = org.h2.jdbcx.JdbcDataSource dataSourceClassName = org.h2.jdbcx.JdbcDataSource
"dataSource.url" = "jdbc:h2:"${basedir}"/persistence" "dataSource.url" = "jdbc:h2:"${basedir}"/persistence;DB_CLOSE_ON_EXIT=FALSE"
"dataSource.user" = sa "dataSource.user" = sa
"dataSource.password" = "" "dataSource.password" = ""
} }

View File

@ -11,24 +11,29 @@ import com.r3corda.core.protocols.ProtocolLogicRef
import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.LogHelper
import com.r3corda.testing.node.TestClock import com.r3corda.testing.node.TestClock
import com.r3corda.node.services.events.NodeSchedulerService import com.r3corda.node.services.events.NodeSchedulerService
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.statemachine.StateMachineManager import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.services.wallet.NodeWalletService
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.testing.ALICE_KEY import com.r3corda.testing.ALICE_KEY
import com.r3corda.testing.node.MockKeyManagementService import com.r3corda.testing.node.MockKeyManagementService
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.io.Closeable
import java.nio.file.FileSystem import java.nio.file.FileSystem
import java.security.PublicKey import java.security.PublicKey
import java.time.Clock import java.time.Clock
import java.time.Instant import java.time.Instant
import java.util.concurrent.CountDownLatch import java.util.concurrent.*
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
class NodeSchedulerServiceTest : SingletonSerializeAsToken() { class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
// Use an in memory file system for testing attachment storage. // Use an in memory file system for testing attachment storage.
@ -38,15 +43,21 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
val stoppedClock = Clock.fixed(realClock.instant(), realClock.zone) val stoppedClock = Clock.fixed(realClock.instant(), realClock.zone)
val testClock = TestClock(stoppedClock) val testClock = TestClock(stoppedClock)
val smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
val schedulerGatedExecutor = AffinityExecutor.Gate(true) val schedulerGatedExecutor = AffinityExecutor.Gate(true)
// We have to allow Java boxed primitives but Kotlin warns we shouldn't be using them // We have to allow Java boxed primitives but Kotlin warns we shouldn't be using them
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name)))) val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name))))
val scheduler: NodeSchedulerService val services: MockServiceHubInternal
val services: ServiceHub
lateinit var scheduler: NodeSchedulerService
lateinit var smmExecutor: AffinityExecutor.ServiceAffinityExecutor
lateinit var dataSource: Closeable
lateinit var countDown: CountDownLatch
lateinit var smmHasRemovedAllProtocols: CountDownLatch
var calls: Int = 0
/** /**
* Have a reference to this test added to [ServiceHub] so that when the [ProtocolLogic] runs it can access the test instance. * Have a reference to this test added to [ServiceHub] so that when the [ProtocolLogic] runs it can access the test instance.
@ -60,22 +71,37 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
init { init {
val kms = MockKeyManagementService(ALICE_KEY) val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None")) val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"))
val mockServices = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference { services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest override val testReference = this@NodeSchedulerServiceTest
} }
services = mockServices
scheduler = NodeSchedulerService(mockServices, factory, schedulerGatedExecutor)
val mockSMM = StateMachineManager(mockServices, listOf(mockServices), PerFileCheckpointStorage(fs.getPath("checkpoints")), smmExecutor)
mockServices.smm = mockSMM
} }
lateinit var countDown: CountDownLatch
var calls: Int = 0
@Before @Before
fun setup() { fun setup() {
countDown = CountDownLatch(1) countDown = CountDownLatch(1)
smmHasRemovedAllProtocols = CountDownLatch(1)
calls = 0 calls = 0
dataSource = configureDatabase(makeTestDataSourceProperties()).first
scheduler = NodeSchedulerService(services, factory, schedulerGatedExecutor)
smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
val mockSMM = StateMachineManager(services, listOf(services), PerFileCheckpointStorage(fs.getPath("checkpoints")), smmExecutor)
mockSMM.changes.subscribe { change:Triple<ProtocolLogic<*>, AddOrRemove, Long> ->
if(change.second==AddOrRemove.REMOVE && mockSMM.allStateMachines.size==0) {
smmHasRemovedAllProtocols.countDown()
}
}
services.smm = mockSMM
}
@After
fun tearDown() {
// We need to make sure the StateMachineManager is done before shutting down executors.
if(services.smm.allStateMachines.isNotEmpty()) {
smmHasRemovedAllProtocols.await()
}
smmExecutor.shutdown()
smmExecutor.awaitTermination(60, TimeUnit.SECONDS)
dataSource.close()
} }
class TestState(val protocolLogicRef: ProtocolLogicRef, val instant: Instant) : LinearState, SchedulableState { class TestState(val protocolLogicRef: ProtocolLogicRef, val instant: Instant) : LinearState, SchedulableState {
@ -109,7 +135,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
assertThat(calls).isEqualTo(0) assertThat(calls).isEqualTo(0)
schedulerGatedExecutor.waitAndRun() schedulerGatedExecutor.waitAndRun()
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(1) assertThat(calls).isEqualTo(1)
} }
@ -120,7 +146,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
assertThat(calls).isEqualTo(0) assertThat(calls).isEqualTo(0)
schedulerGatedExecutor.waitAndRun() schedulerGatedExecutor.waitAndRun()
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(1) assertThat(calls).isEqualTo(1)
} }
@ -135,7 +161,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
testClock.advanceBy(1.days) testClock.advanceBy(1.days)
backgroundExecutor.shutdown() backgroundExecutor.shutdown()
backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS) backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS)
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(1) assertThat(calls).isEqualTo(1)
} }
@ -151,7 +177,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() } backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() }
testClock.advanceBy(1.days) testClock.advanceBy(1.days)
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(3) assertThat(calls).isEqualTo(3)
backgroundExecutor.shutdown() backgroundExecutor.shutdown()
backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS) backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS)
@ -169,9 +195,10 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() } backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() }
testClock.advanceBy(1.days) testClock.advanceBy(1.days)
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(1) assertThat(calls).isEqualTo(1)
testClock.advanceBy(1.days) testClock.advanceBy(1.days)
backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() }
backgroundExecutor.shutdown() backgroundExecutor.shutdown()
backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS) backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS)
} }
@ -187,7 +214,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
scheduleTX(time, 3) scheduleTX(time, 3)
testClock.advanceBy(1.days) testClock.advanceBy(1.days)
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(1) assertThat(calls).isEqualTo(1)
backgroundExecutor.shutdown() backgroundExecutor.shutdown()
backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS) backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS)
@ -206,7 +233,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() } backgroundExecutor.execute { schedulerGatedExecutor.waitAndRun() }
scheduler.unscheduleStateActivity(scheduledRef1!!.ref) scheduler.unscheduleStateActivity(scheduledRef1!!.ref)
testClock.advanceBy(1.days) testClock.advanceBy(1.days)
countDown.await(60, TimeUnit.SECONDS) countDown.await()
assertThat(calls).isEqualTo(3) assertThat(calls).isEqualTo(3)
backgroundExecutor.shutdown() backgroundExecutor.shutdown()
backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS) backgroundExecutor.awaitTermination(60, TimeUnit.SECONDS)

View File

@ -5,13 +5,21 @@ import com.google.common.util.concurrent.Futures
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.PhysicalLocation import com.r3corda.core.node.PhysicalLocation
import com.r3corda.core.node.services.KeyManagementService
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.WalletService import com.r3corda.core.node.services.WalletService
import com.r3corda.core.testing.InMemoryWalletService import com.r3corda.core.testing.InMemoryWalletService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.r3corda.testing.node.TestTransactionManager
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.keys.E2ETestKeyManagementService
import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.network.NodeRegistration
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger import org.slf4j.Logger
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
@ -81,10 +89,29 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
override fun makeWalletService(): WalletService = InMemoryWalletService(services) override fun makeWalletService(): WalletService = InMemoryWalletService(services)
override fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(setOf(storage.myLegalIdentityKey))
override fun startMessagingService() { override fun startMessagingService() {
// Nothing to do // Nothing to do
} }
// If the in-memory H2 instance is configured, use that, otherwise mock out the transaction manager.
override fun initialiseDatabasePersistence(insideTransaction: () -> Unit) {
try {
super.initialiseDatabasePersistence(insideTransaction)
} catch(fallback: DatabaseConfigurationException) {
log.info("Using mocked database features.")
TransactionManager.manager = TestTransactionManager()
insideTransaction()
}
}
override fun makeNetworkMapService() {
val expires = platformClock.instant() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD
val reg = NodeRegistration(info, Long.MAX_VALUE, AddOrRemove.ADD, expires)
inNodeNetworkMapService = InMemoryNetworkMapService(net, reg, services.networkMapCache)
}
override fun generateKeyPair(): KeyPair = keyPair ?: super.generateKeyPair() override fun generateKeyPair(): KeyPair = keyPair ?: super.generateKeyPair()
// It's OK to not have a network map service in the mock network. // It's OK to not have a network map service in the mock network.

View File

@ -2,7 +2,6 @@ package com.r3corda.testing.node
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.Attachment import com.r3corda.core.contracts.Attachment
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.crypto.generateKeyPair
@ -12,6 +11,7 @@ import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.* import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.testing.MEGA_CORP import com.r3corda.testing.MEGA_CORP
import com.r3corda.testing.MINI_CORP import com.r3corda.testing.MINI_CORP
@ -149,7 +149,7 @@ class MockStorageService(override val attachments: AttachmentStorage = MockAttac
fun makeTestDataSourceProperties(nodeName: String = SecureHash.randomSHA256().toString()): Properties { fun makeTestDataSourceProperties(nodeName: String = SecureHash.randomSHA256().toString()): Properties {
val props = Properties() val props = Properties()
props.setProperty("dataSourceClassName", "org.h2.jdbcx.JdbcDataSource") props.setProperty("dataSourceClassName", "org.h2.jdbcx.JdbcDataSource")
props.setProperty("dataSource.url", "jdbc:h2:mem:${nodeName}_persistence") props.setProperty("dataSource.url", "jdbc:h2:mem:${nodeName}_persistence;DB_CLOSE_ON_EXIT=FALSE")
props.setProperty("dataSource.user", "sa") props.setProperty("dataSource.user", "sa")
props.setProperty("dataSource.password", "") props.setProperty("dataSource.password", "")
return props return props

View File

@ -0,0 +1,43 @@
package com.r3corda.testing.node
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionInterface
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.Connection
/**
* A dummy transaction manager used by [MockNode] to avoid uninitialised lateinit var. Any attempt to use this results in an exception.
*/
class TestTransactionManager : TransactionManager {
var current = ThreadLocal<Transaction>()
override fun currentOrNull() = current.get()
override fun newTransaction(isolation: Int): Transaction {
val newTx = Transaction(TestTransactionImpl(this))
current.set(newTx)
return newTx
}
class TestTransactionImpl(val manager: TestTransactionManager) : TransactionInterface {
override val connection: Connection
get() = throw UnsupportedOperationException()
override val db: Database
get() = throw UnsupportedOperationException()
override val outerTransaction: Transaction?
get() = throw UnsupportedOperationException()
override fun close() {
manager.current.set(null)
}
override fun commit() {
}
override fun rollback() {
}
}
}