diff --git a/core/src/test/kotlin/com/r3corda/core/protocols/ResolveTransactionsProtocolTest.kt b/core/src/test/kotlin/com/r3corda/core/protocols/ResolveTransactionsProtocolTest.kt index d91d7ccc19..fcbce44151 100644 --- a/core/src/test/kotlin/com/r3corda/core/protocols/ResolveTransactionsProtocolTest.kt +++ b/core/src/test/kotlin/com/r3corda/core/protocols/ResolveTransactionsProtocolTest.kt @@ -8,6 +8,7 @@ import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.recordTransactions import com.r3corda.core.serialization.opaque import com.r3corda.core.utilities.DUMMY_NOTARY_KEY +import com.r3corda.node.utilities.databaseTransaction import com.r3corda.testing.node.MockNetwork import com.r3corda.protocols.ResolveTransactionsProtocol import com.r3corda.testing.* @@ -49,8 +50,10 @@ class ResolveTransactionsProtocolTest { net.runNetwork() val results = future.get() assertEquals(listOf(stx1.id, stx2.id), results.map { it.id }) - assertEquals(stx1, b.storage.validatedTransactions.getTransaction(stx1.id)) - assertEquals(stx2, b.storage.validatedTransactions.getTransaction(stx2.id)) + databaseTransaction(b.database) { + assertEquals(stx1, b.storage.validatedTransactions.getTransaction(stx1.id)) + assertEquals(stx2, b.storage.validatedTransactions.getTransaction(stx2.id)) + } } @Test @@ -71,9 +74,11 @@ class ResolveTransactionsProtocolTest { val future = b.services.startProtocol(p) net.runNetwork() future.get() - assertEquals(stx1, b.storage.validatedTransactions.getTransaction(stx1.id)) - // But stx2 wasn't inserted, just stx1. - assertNull(b.storage.validatedTransactions.getTransaction(stx2.id)) + databaseTransaction(b.database) { + assertEquals(stx1, b.storage.validatedTransactions.getTransaction(stx1.id)) + // But stx2 wasn't inserted, just stx1. + assertNull(b.storage.validatedTransactions.getTransaction(stx2.id)) + } } @Test @@ -86,7 +91,9 @@ class ResolveTransactionsProtocolTest { val stx = DummyContract.move(cursor.tx.outRef(0), MINI_CORP_PUBKEY) .addSignatureUnchecked(NullSignature) .toSignedTransaction(false) - a.services.recordTransactions(stx) + databaseTransaction(a.database) { + a.services.recordTransactions(stx) + } cursor = stx } val p = ResolveTransactionsProtocol(setOf(cursor.id), a.info.legalIdentity) @@ -114,7 +121,9 @@ class ResolveTransactionsProtocolTest { toSignedTransaction() } - a.services.recordTransactions(stx2, stx3) + databaseTransaction(a.database) { + a.services.recordTransactions(stx2, stx3) + } val p = ResolveTransactionsProtocol(setOf(stx3.id), a.info.legalIdentity) val future = b.services.startProtocol(p) @@ -148,7 +157,9 @@ class ResolveTransactionsProtocolTest { it.signWith(DUMMY_NOTARY_KEY) it.toSignedTransaction() } - a.services.recordTransactions(dummy1, dummy2) + databaseTransaction(a.database) { + a.services.recordTransactions(dummy1, dummy2) + } return Pair(dummy1, dummy2) } } diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index a5893b4455..63c13faff5 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -123,7 +123,11 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo return protocolFactories[markerClass] } - override fun recordTransactions(txs: Iterable) = recordTransactionsInternal(storage, txs) + override fun recordTransactions(txs: Iterable) { + databaseTransaction(database) { + recordTransactionsInternal(storage, txs) + } + } } val info: NodeInfo by lazy { @@ -206,7 +210,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo // the identity key. But the infrastructure to make that easy isn't here yet. keyManagement = makeKeyManagementService() api = APIServerImpl(this@AbstractNode) - scheduler = NodeSchedulerService(services) + scheduler = NodeSchedulerService(database, services) protocolLogicFactory = initialiseProtocolLogicFactory() @@ -440,7 +444,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo protected open fun initialiseStorageService(dir: Path): Pair { val attachments = makeAttachmentStorage(dir) val checkpointStorage = initialiseCheckpointService(dir) - val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions")) + val transactionStorage = DBTransactionStorage() _servicesThatAcceptUploads += attachments // Populate the partyKeys set. obtainKeyPair(dir, PRIVATE_KEY_FILE_NAME, PUBLIC_IDENTITY_FILE_NAME) diff --git a/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt b/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt index d9df88437f..6c4cfedf4a 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt @@ -6,7 +6,9 @@ import com.r3corda.core.contracts.* import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.toStringShort import com.r3corda.core.node.ServiceHub +import com.r3corda.core.node.services.StateMachineTransactionMapping import com.r3corda.core.node.services.Vault +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.messaging.StateMachineInfo @@ -37,7 +39,12 @@ class ServerRPCOps( Pair(vault.states.toList(), updates) } } - override fun verifiedTransactions() = services.storageService.validatedTransactions.track() + override fun verifiedTransactions(): Pair, Observable> { + return databaseTransaction(database) { + services.storageService.validatedTransactions.track() + } + } + override fun stateMachinesAndUpdates(): Pair, Observable> { val (allStateMachines, changes) = smm.track() return Pair( @@ -45,7 +52,11 @@ class ServerRPCOps( changes.map { StateMachineUpdate.fromStateMachineChange(it) } ) } - override fun stateMachineRecordedTransactionMapping() = services.storageService.stateMachineRecordedTransactionMapping.track() + override fun stateMachineRecordedTransactionMapping(): Pair, Observable> { + return databaseTransaction(database) { + services.storageService.stateMachineRecordedTransactionMapping.track() + } + } override fun executeCommand(command: ClientToServiceCommand): TransactionBuildResult { return databaseTransaction(database) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/com/r3corda/node/services/events/NodeSchedulerService.kt index ed953b34b0..df03d70e32 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/events/NodeSchedulerService.kt @@ -12,6 +12,8 @@ import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.trace import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.utilities.awaitWithDeadline +import com.r3corda.node.utilities.databaseTransaction +import org.jetbrains.exposed.sql.Database import java.time.Instant import java.util.* import java.util.concurrent.Executor @@ -38,7 +40,8 @@ import javax.annotation.concurrent.ThreadSafe * activity. Only replace this for unit testing purposes. This is not the executor the [ProtocolLogic] is launched on. */ @ThreadSafe -class NodeSchedulerService(private val services: ServiceHubInternal, +class NodeSchedulerService(private val database: Database, + private val services: ServiceHubInternal, private val protocolLogicRefFactory: ProtocolLogicRefFactory = ProtocolLogicRefFactory(), private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor()) : SchedulerService, SingletonSerializeAsToken() { @@ -121,7 +124,9 @@ class NodeSchedulerService(private val services: ServiceHubInternal, private fun onTimeReached(scheduledState: ScheduledStateRef) { try { - runScheduledActionForState(scheduledState) + databaseTransaction(database) { + runScheduledActionForState(scheduledState) + } } finally { // Unschedule once complete (or checkpointed) mutex.locked { diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/DBTransactionStorage.kt new file mode 100644 index 0000000000..2a16f541f7 --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/DBTransactionStorage.kt @@ -0,0 +1,64 @@ +package com.r3corda.node.services.persistence + +import com.google.common.annotations.VisibleForTesting +import com.r3corda.core.bufferUntilSubscribed +import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.node.services.TransactionStorage +import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.node.utilities.* +import org.jetbrains.exposed.sql.ResultRow +import org.jetbrains.exposed.sql.statements.InsertStatement +import rx.Observable +import rx.subjects.PublishSubject +import java.util.Collections.synchronizedMap + +class DBTransactionStorage : TransactionStorage { + private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transactions") { + val txId = secureHash("tx_id") + val transaction = blob("transaction") + } + + private class TransactionsMap : AbstractJDBCHashMap(Table, loadOnInit = false) { + override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] + + override fun valueFromRow(row: ResultRow): SignedTransaction = deserializeFromBlob(row[table.transaction]) + + override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { + insert[table.txId] = entry.key + } + + override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { + insert[table.transaction] = serializeToBlob(entry.value, finalizables) + } + } + + private val txStorage = synchronizedMap(TransactionsMap()) + + override fun addTransaction(transaction: SignedTransaction) { + synchronized(txStorage) { + txStorage.put(transaction.id, transaction) + updatesPublisher.onNext(transaction) + } + } + + override fun getTransaction(id: SecureHash): SignedTransaction? { + synchronized(txStorage) { + return txStorage.get(id) + } + } + + val updatesPublisher = PublishSubject.create().toSerialized() + override val updates: Observable + get() = updatesPublisher + + override fun track(): Pair, Observable> { + synchronized(txStorage) { + return Pair(txStorage.values.toList(), updates.bufferUntilSubscribed()) + } + } + + @VisibleForTesting + val transactions: Iterable get() = synchronized(txStorage) { + txStorage.values.toList() + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index b5a3041475..c035bfa5e2 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -21,10 +21,7 @@ import com.r3corda.core.utilities.LogHelper import com.r3corda.core.utilities.TEST_TX_TIME import com.r3corda.node.internal.AbstractNode import com.r3corda.node.services.config.NodeConfiguration -import com.r3corda.node.services.persistence.NodeAttachmentService -import com.r3corda.node.services.persistence.PerFileTransactionStorage -import com.r3corda.node.services.persistence.StorageServiceImpl -import com.r3corda.node.services.persistence.checkpoints +import com.r3corda.node.services.persistence.* import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Seller @@ -32,6 +29,7 @@ import com.r3corda.testing.* import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockNetwork import org.assertj.core.api.Assertions.assertThat +import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test @@ -126,6 +124,8 @@ class TwoPartyTradeProtocolTests { notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) + aliceNode.disableDBCloseOnStop() + bobNode.disableDBCloseOnStop() val aliceKey = aliceNode.services.legalIdentityKey val notaryKey = notaryNode.services.notaryIdentityKey @@ -157,7 +157,14 @@ class TwoPartyTradeProtocolTests { // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1) - val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions + val storage = bobNode.storage.validatedTransactions + val bobTransactionsBeforeCrash = if (storage is PerFileTransactionStorage) { + storage.transactions + } else if (storage is DBTransactionStorage) { + databaseTransaction(bobNode.database) { + storage.transactions + } + } else throw IllegalArgumentException("Unknown storage implementation") assertThat(bobTransactionsBeforeCrash).isNotEmpty() // .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk. @@ -186,12 +193,20 @@ class TwoPartyTradeProtocolTests { assertThat(bobFuture.get()).isEqualTo(aliceFuture.get()) assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty() + databaseTransaction(bobNode.database) { + assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() + } + databaseTransaction(aliceNode.database) { + assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty() + } - assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() - assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty() + databaseTransaction(bobNode.database) { + val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null } + assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash) + } - val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null } - assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash) + aliceNode.manuallyCloseDB() + bobNode.manuallyCloseDB() } } @@ -209,7 +224,7 @@ class TwoPartyTradeProtocolTests { transactionStorage: TransactionStorage, stateMachineRecordedTransactionMappingStorage: StateMachineRecordedTransactionMappingStorage ): StorageServiceImpl { - return StorageServiceImpl(attachments, RecordingTransactionStorage(transactionStorage), stateMachineRecordedTransactionMappingStorage) + return StorageServiceImpl(attachments, RecordingTransactionStorage(database, transactionStorage), stateMachineRecordedTransactionMappingStorage) } } } @@ -529,9 +544,11 @@ class TwoPartyTradeProtocolTests { } - class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage { + class RecordingTransactionStorage(val database: Database, val delegate: TransactionStorage) : TransactionStorage { override fun track(): Pair, Observable> { - return delegate.track() + return databaseTransaction(database) { + delegate.track() + } } val records: MutableList = Collections.synchronizedList(ArrayList()) @@ -539,13 +556,17 @@ class TwoPartyTradeProtocolTests { get() = delegate.updates override fun addTransaction(transaction: SignedTransaction) { - records.add(TxRecord.Add(transaction)) - delegate.addTransaction(transaction) + databaseTransaction(database) { + records.add(TxRecord.Add(transaction)) + delegate.addTransaction(transaction) + } } override fun getTransaction(id: SecureHash): SignedTransaction? { - records.add(TxRecord.Get(id)) - return delegate.getTransaction(id) + return databaseTransaction(database) { + records.add(TxRecord.Get(id)) + delegate.getTransaction(id) + } } } diff --git a/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt index d941efe38e..c917e96e52 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt @@ -85,7 +85,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) dataSource = dataSourceAndDatabase.first val database = dataSourceAndDatabase.second - scheduler = NodeSchedulerService(services, factory, schedulerGatedExecutor) + scheduler = NodeSchedulerService(database, services, factory, schedulerGatedExecutor) smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1) val mockSMM = StateMachineManager(services, listOf(services), PerFileCheckpointStorage(fs.getPath("checkpoints")), smmExecutor, database) mockSMM.changes.subscribe { change -> diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/DBTransactionStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/DBTransactionStorageTests.kt new file mode 100644 index 0000000000..2584185081 --- /dev/null +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/DBTransactionStorageTests.kt @@ -0,0 +1,135 @@ +package com.r3corda.node.services.persistence + +import com.google.common.primitives.Ints +import com.google.common.util.concurrent.SettableFuture +import com.r3corda.core.crypto.DigitalSignature +import com.r3corda.core.crypto.NullPublicKey +import com.r3corda.core.serialization.SerializedBytes +import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.core.utilities.LogHelper +import com.r3corda.node.services.transactions.PersistentUniquenessProvider +import com.r3corda.node.utilities.configureDatabase +import com.r3corda.node.utilities.databaseTransaction +import com.r3corda.testing.node.makeTestDataSourceProperties +import org.assertj.core.api.Assertions.assertThat +import org.jetbrains.exposed.sql.Database +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.io.Closeable +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals + +class DBTransactionStorageTests { + lateinit var dataSource: Closeable + lateinit var database: Database + lateinit var transactionStorage: DBTransactionStorage + + @Before + fun setUp() { + LogHelper.setLevel(PersistentUniquenessProvider::class) + val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) + dataSource = dataSourceAndDatabase.first + database = dataSourceAndDatabase.second + newTransactionStorage() + } + + @After + fun cleanUp() { + dataSource.close() + LogHelper.reset(PersistentUniquenessProvider::class) + } + + @Test + fun `empty store`() { + databaseTransaction(database) { + assertThat(transactionStorage.getTransaction(newTransaction().id)).isNull() + } + databaseTransaction(database) { + assertThat(transactionStorage.transactions).isEmpty() + } + newTransactionStorage() + databaseTransaction(database) { + assertThat(transactionStorage.transactions).isEmpty() + } + } + + @Test + fun `one transaction`() { + val transaction = newTransaction() + databaseTransaction(database) { + transactionStorage.addTransaction(transaction) + } + assertTransactionIsRetrievable(transaction) + databaseTransaction(database) { + assertThat(transactionStorage.transactions).containsExactly(transaction) + } + newTransactionStorage() + assertTransactionIsRetrievable(transaction) + databaseTransaction(database) { + assertThat(transactionStorage.transactions).containsExactly(transaction) + } + } + + @Test + fun `two transactions across restart`() { + val firstTransaction = newTransaction() + val secondTransaction = newTransaction() + databaseTransaction(database) { + transactionStorage.addTransaction(firstTransaction) + } + newTransactionStorage() + databaseTransaction(database) { + transactionStorage.addTransaction(secondTransaction) + } + assertTransactionIsRetrievable(firstTransaction) + assertTransactionIsRetrievable(secondTransaction) + databaseTransaction(database) { + assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction) + } + } + + @Test + fun `two transactions in same DB transaction scope`() { + val firstTransaction = newTransaction() + val secondTransaction = newTransaction() + databaseTransaction(database) { + transactionStorage.addTransaction(firstTransaction) + transactionStorage.addTransaction(secondTransaction) + } + assertTransactionIsRetrievable(firstTransaction) + assertTransactionIsRetrievable(secondTransaction) + databaseTransaction(database) { + assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction) + } + } + + @Test + fun `updates are fired`() { + val future = SettableFuture.create() + transactionStorage.updates.subscribe { tx -> future.set(tx) } + val expected = newTransaction() + databaseTransaction(database) { + transactionStorage.addTransaction(expected) + } + val actual = future.get(1, TimeUnit.SECONDS) + assertEquals(expected, actual) + } + + private fun newTransactionStorage() { + databaseTransaction(database) { + transactionStorage = DBTransactionStorage() + } + } + + private fun assertTransactionIsRetrievable(transaction: SignedTransaction) { + databaseTransaction(database) { + assertThat(transactionStorage.getTransaction(transaction.id)).isEqualTo(transaction) + } + } + + private var txCount = 0 + private fun newTransaction() = SignedTransaction( + SerializedBytes(Ints.toByteArray(++txCount)), + listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1)))) +} \ No newline at end of file