From a6bf8e35dd385c6669dc35b7ec2f4f73ef9aab2c Mon Sep 17 00:00:00 2001 From: szymonsztuka Date: Mon, 14 Aug 2017 15:06:06 +0100 Subject: [PATCH 1/3] rewrite few services to use Hibernate * DBTransactionMappingStorage * DBTransactionStorage * DBCheckpointStorage * PersistentUniquenessProvider * PersistentKeyManagementService --- ...bstractPartyToX500NameAsStringConverter.kt | 10 +- .../node/utilities/JDBCHashMapTestSuite.kt | 11 +- .../net/corda/node/internal/AbstractNode.kt | 6 +- .../database/HibernateConfiguration.kt | 2 +- .../keys/PersistentKeyManagementService.kt | 73 ++++----- .../persistence/DBCheckpointStorage.kt | 71 ++++----- .../DBTransactionMappingStorage.kt | 75 +++++----- .../persistence/DBTransactionStorage.kt | 81 +++++----- .../node/services/schema/NodeSchemaService.kt | 20 ++- .../PersistentUniquenessProvider.kt | 138 +++++++++++------- .../node/utilities/AppendOnlyPersistentMap.kt | 113 ++++++++++++++ .../corda/node/utilities/CordaPersistence.kt | 44 +++--- .../utilities/DatabaseTransactionManager.kt | 17 +++ .../node/utilities/NonInvalidatingCache.kt | 33 +++++ .../database/HibernateConfigurationTest.kt | 7 +- .../database/RequeryConfigurationTest.kt | 8 +- .../events/NodeSchedulerServiceTest.kt | 10 +- .../messaging/ArtemisMessagingTests.kt | 4 +- .../persistence/DBCheckpointStorageTests.kt | 14 +- .../persistence/DBTransactionStorageTests.kt | 82 ++++++++++- .../persistence/NodeAttachmentStorageTest.kt | 3 +- .../services/schema/HibernateObserverTests.kt | 6 +- .../DistributedImmutableMapTests.kt | 7 +- .../PersistentUniquenessProviderTests.kt | 8 +- .../node/services/vault/VaultQueryTests.kt | 7 +- .../corda/node/utilities/ObservablesTests.kt | 3 +- .../corda/irs/api/NodeInterestRatesTest.kt | 7 +- .../net/corda/testing/node/MockServices.kt | 10 +- .../net/corda/testing/node/SimpleNode.kt | 2 +- 29 files changed, 563 insertions(+), 309 deletions(-) create mode 100644 node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt create mode 100644 node/src/main/kotlin/net/corda/node/utilities/NonInvalidatingCache.kt diff --git a/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt b/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt index 3d88947dbd..9ee9503eef 100644 --- a/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt +++ b/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt @@ -11,18 +11,22 @@ import javax.persistence.Converter * Completely anonymous parties are stored as null (to preserve privacy) */ @Converter(autoApply = true) -class AbstractPartyToX500NameAsStringConverter(val identitySvc: IdentityService) : AttributeConverter { +class AbstractPartyToX500NameAsStringConverter(identitySvc: () -> IdentityService) : AttributeConverter { + + private val identityService: IdentityService by lazy { + identitySvc() + } override fun convertToDatabaseColumn(party: AbstractParty?): String? { party?.let { - return identitySvc.partyFromAnonymous(party)?.toString() + return identityService.partyFromAnonymous(party)?.toString() } return null // non resolvable anonymous parties } override fun convertToEntityAttribute(dbData: String?): AbstractParty? { dbData?.let { - val party = identitySvc.partyFromX500Name(X500Name(dbData)) + val party = identityService.partyFromX500Name(X500Name(dbData)) return party as AbstractParty } return null // non resolvable anonymous parties are stored as nulls diff --git a/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt b/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt index a66304388e..0057ba0470 100644 --- a/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt +++ b/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt @@ -10,18 +10,13 @@ import com.google.common.collect.testing.features.MapFeature import com.google.common.collect.testing.features.SetFeature import com.google.common.collect.testing.testers.* import junit.framework.TestSuite -import net.corda.testing.TestDependencyInjectionBase -import net.corda.testing.initialiseTestSerialization +import net.corda.testing.* import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties -import net.corda.testing.resetTestSerialization import org.assertj.core.api.Assertions.assertThat -import org.jetbrains.exposed.sql.Transaction -import org.jetbrains.exposed.sql.transactions.TransactionManager import org.junit.* import org.junit.runner.RunWith import org.junit.runners.Suite -import java.sql.Connection import java.util.* @RunWith(Suite::class) @@ -47,7 +42,7 @@ class JDBCHashMapTestSuite { @BeforeClass fun before() { initialiseTestSerialization() - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = { throw UnsupportedOperationException("Identity Service should not be in use") }) setUpDatabaseTx() loadOnInitFalseMap = JDBCHashMap("test_map_false", loadOnInit = false) memoryConstrainedMap = JDBCHashMap("test_map_constrained", loadOnInit = false, maxBuckets = 1) @@ -233,7 +228,7 @@ class JDBCHashMapTestSuite { @Before fun before() { - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = { throw UnsupportedOperationException("Identity Service should not be in use") }) } @After diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 32caa6f7a3..71848d1616 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -486,7 +486,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, private fun makeVaultObservers() { VaultSoftLockManager(services.vaultService, smm) ScheduledActivityObserver(services) - HibernateObserver(services.vaultService.rawUpdates, HibernateConfiguration(services.schemaService, configuration.database ?: Properties(), services.identityService)) + HibernateObserver(services.vaultService.rawUpdates, HibernateConfiguration(services.schemaService, configuration.database ?: Properties(), {services.identityService})) } private fun makeInfo(): NodeInfo { @@ -545,7 +545,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, protected open fun initialiseDatabasePersistence(insideTransaction: () -> Unit) { val props = configuration.dataSourceProperties if (props.isNotEmpty()) { - this.database = configureDatabase(props, configuration.database) + this.database = configureDatabase(props, configuration.database, identitySvc = { _services.identityService }) // Now log the vendor string as this will also cause a connection to be tested eagerly. database.transaction { log.info("Connected to ${database.database.vendor} database.") @@ -773,7 +773,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, override val networkMapCache by lazy { InMemoryNetworkMapCache(this) } override val vaultService by lazy { NodeVaultService(this, configuration.dataSourceProperties, configuration.database) } override val vaultQueryService by lazy { - HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), identityService), vaultService.updatesPublisher) + HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), { identityService }), vaultService.updatesPublisher) } // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with diff --git a/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt index 4ab3dc248f..d0f300b9ee 100644 --- a/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt @@ -20,7 +20,7 @@ import java.sql.Connection import java.util.* import java.util.concurrent.ConcurrentHashMap -class HibernateConfiguration(val schemaService: SchemaService, val databaseProperties: Properties, val identitySvc: IdentityService) { +class HibernateConfiguration(val schemaService: SchemaService, val databaseProperties: Properties, private val identitySvc: () -> IdentityService) { companion object { val logger = loggerFor() } diff --git a/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt b/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt index d07e7958fc..4eae474f3f 100644 --- a/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt +++ b/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt @@ -6,14 +6,13 @@ import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.ThreadBox import net.corda.core.node.services.IdentityService import net.corda.core.node.services.KeyManagementService -import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.serialization.* import net.corda.node.utilities.* import org.bouncycastle.operator.ContentSigner -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey +import javax.persistence.* /** * A persistent re-implementation of [E2ETestKeyManagementService] to support node re-start. @@ -25,60 +24,62 @@ import java.security.PublicKey class PersistentKeyManagementService(val identityService: IdentityService, initialKeys: Set) : SingletonSerializeAsToken(), KeyManagementService { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}our_key_pairs") { - val publicKey = publicKey("public_key") - val privateKey = blob("private_key") - } + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}our_key_pairs") + class PersistentKey( - private class InnerState { - val keys = object : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): PublicKey = row[table.publicKey] + @Id + @Column(name = "public_key") + var publicKey: String = "", - override fun valueFromRow(row: ResultRow): PrivateKey = deserializeFromBlob(row[table.privateKey]) + @Lob + @Column(name = "private_key") + var privateKey: ByteArray = ByteArray(0) + ) - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.publicKey] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.privateKey] = serializeToBlob(entry.value, finalizables) - } + private companion object { + fun createKeyMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toBase58String() }, + fromPersistentEntity = { Pair(parsePublicKeyBase58(it.publicKey), + it.privateKey.deserialize(context = SerializationDefaults.STORAGE_CONTEXT)) }, + toPersistentEntity = { key: PublicKey, value: PrivateKey -> + PersistentKey().apply { + publicKey = key.toBase58String() + privateKey = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + } + }, + persistentEntityClass = PersistentKey::class.java + ) } } - private val mutex = ThreadBox(InnerState()) + val keysMap = createKeyMap() init { - mutex.locked { - keys.putAll(initialKeys.associate { Pair(it.public, it.private) }) - } + initialKeys.forEach({ it -> keysMap.addWithDuplicatesAllowed(it.public, it.private) }) } - override val keys: Set get() = mutex.locked { keys.keys } + override val keys: Set get() = keysMap.allPersisted().map { it.first }.toSet() - override fun filterMyKeys(candidateKeys: Iterable): Iterable { - return mutex.locked { candidateKeys.filter { it in this.keys } } - } + override fun filterMyKeys(candidateKeys: Iterable): Iterable = + candidateKeys.filter { keysMap[it] != null } override fun freshKey(): PublicKey { val keyPair = generateKeyPair() - mutex.locked { - keys[keyPair.public] = keyPair.private - } + keysMap[keyPair.public] = keyPair.private return keyPair.public } - override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymousPartyAndPath { - return freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) - } + override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymousPartyAndPath = + freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) private fun getSigner(publicKey: PublicKey): ContentSigner = getSigner(getSigningKeyPair(publicKey)) + //It looks for the PublicKey in the (potentially) CompositeKey that is ours, and then returns the associated PrivateKey to use in signing private fun getSigningKeyPair(publicKey: PublicKey): KeyPair { - return mutex.locked { - val pk = publicKey.keys.first { keys.containsKey(it) } - KeyPair(pk, keys[pk]!!) - } + val pk = publicKey.keys.first { keysMap[it] != null } //TODO here for us to re-write this using an actual query if publicKey.keys.size > 1 + return KeyPair(pk, keysMap[pk]!!) } override fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index 1d4ccde166..b977f36875 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -1,57 +1,60 @@ package net.corda.node.services.persistence -import net.corda.core.crypto.SecureHash -import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT -import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.utilities.* -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement -import java.util.Collections.synchronizedMap +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Id +import javax.persistence.Lob /** - * Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites. + * Simple checkpoint key value storage in DB. */ class DBCheckpointStorage : CheckpointStorage { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}checkpoints") { - val checkpointId = secureHash("checkpoint_id") - val checkpoint = blob("checkpoint") - } + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints") + class DBCheckpoint( + @Id + @Column(name = "checkpoint_id", length = 64) + var checkpointId: String = "", - private class CheckpointMap : AbstractJDBCHashMap, Table>(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): SecureHash = row[table.checkpointId] + @Lob + @Column(name = "checkpoint") + var checkpoint: ByteArray = ByteArray(0) + ) - override fun valueFromRow(row: ResultRow): SerializedBytes = bytesFromBlob(row[table.checkpoint]) - - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { - insert[table.checkpointId] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { - insert[table.checkpoint] = bytesToBlob(entry.value, finalizables) - } - } - - private val checkpointStorage = synchronizedMap(CheckpointMap()) - - override fun addCheckpoint(checkpoint: Checkpoint) { - checkpointStorage.put(checkpoint.id, checkpoint.serialize(context = CHECKPOINT_CONTEXT)) + override fun addCheckpoint(value: Checkpoint) { + val session = DatabaseTransactionManager.current().session + session.save(DBCheckpoint().apply { + checkpointId = value.id.toString() + checkpoint = value.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes + }) } override fun removeCheckpoint(checkpoint: Checkpoint) { - checkpointStorage.remove(checkpoint.id) ?: throw IllegalArgumentException("Checkpoint not found") + val session = DatabaseTransactionManager.current().session + val criteriaBuilder = session.criteriaBuilder + val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java) + val root = delete.from(DBCheckpoint::class.java) + delete.where(criteriaBuilder.equal(root.get(DBCheckpoint::checkpointId.name), checkpoint.id.toString())) + session.createQuery(delete).executeUpdate() } override fun forEach(block: (Checkpoint) -> Boolean) { - synchronized(checkpointStorage) { - for (checkpoint in checkpointStorage.values) { - if (!block(checkpoint.deserialize(context = CHECKPOINT_CONTEXT))) { - break - } + val session = DatabaseTransactionManager.current().session + val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) + val root = criteriaQuery.from(DBCheckpoint::class.java) + criteriaQuery.select(root) + val query = session.createQuery(criteriaQuery) + val checkpoints = query.resultList.map { e -> e.checkpoint.deserialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) }.asSequence() + for (e in checkpoints) { + if (!block(e)) { + break } } } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt index 1e9b9da71e..6c24e81ab0 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt @@ -1,6 +1,5 @@ package net.corda.node.services.persistence -import net.corda.core.internal.ThreadBox import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.crypto.SecureHash import net.corda.core.flows.StateMachineRunId @@ -8,59 +7,57 @@ import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.utilities.* -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement import rx.subjects.PublishSubject +import java.util.* import javax.annotation.concurrent.ThreadSafe +import javax.persistence.* /** * Database storage of a txhash -> state machine id mapping. * * Mappings are added as transactions are persisted by [ServiceHub.recordTransaction], and never deleted. Used in the * RPC API to correlate transaction creation with flows. - * */ @ThreadSafe class DBTransactionMappingStorage : StateMachineRecordedTransactionMappingStorage { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transaction_mappings") { - val txId = secureHash("tx_id") - val stateMachineRunId = uuidString("state_machine_run_id") - } + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}transaction_mappings") + class DBTransactionMapping( + @Id + @Column(name = "tx_id", length = 64) + var txId: String = "", - private class TransactionMappingsMap : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] + @Column(name = "state_machine_run_id", length = 36) + var stateMachineRunId: String = "" + ) - override fun valueFromRow(row: ResultRow): StateMachineRunId = StateMachineRunId(row[table.stateMachineRunId]) - - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.txId] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.stateMachineRunId] = entry.value.uuid - } - } - - private class InnerState { - val stateMachineTransactionMap = TransactionMappingsMap() - val updates: PublishSubject = PublishSubject.create() - } - private val mutex = ThreadBox(InnerState()) - - override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) { - mutex.locked { - stateMachineTransactionMap[transactionId] = stateMachineRunId - updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId)) - } - } - - override fun track(): DataFeed, StateMachineTransactionMapping> { - mutex.locked { - return DataFeed( - stateMachineTransactionMap.map { StateMachineTransactionMapping(it.value, it.key) }, - updates.bufferUntilSubscribed().wrapWithDatabaseTransaction() + private companion object { + fun createMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toString() }, + fromPersistentEntity = { Pair(SecureHash.parse(it.txId), StateMachineRunId(UUID.fromString(it.stateMachineRunId))) }, + toPersistentEntity = { key: SecureHash, value: StateMachineRunId -> + DBTransactionMapping().apply { + txId = key.toString() + stateMachineRunId = value.uuid.toString() + } + }, + persistentEntityClass = DBTransactionMapping::class.java ) } } + + val stateMachineTransactionMap = createMap() + val updates: PublishSubject = PublishSubject.create() + + override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) { + stateMachineTransactionMap[transactionId] = stateMachineRunId + updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId)) + } + + override fun track(): DataFeed, StateMachineTransactionMapping> = + DataFeed(stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(), + updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index dd63f08291..fb637abd4c 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -4,73 +4,60 @@ import com.google.common.annotations.VisibleForTesting import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.crypto.SecureHash import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.serialization.* import net.corda.core.transactions.SignedTransaction import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.utilities.* -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.exposedLogger -import org.jetbrains.exposed.sql.statements.InsertStatement import rx.Observable import rx.subjects.PublishSubject -import java.util.Collections.synchronizedMap +import javax.persistence.* class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transactions") { - val txId = secureHash("tx_id") - val transaction = blob("transaction") - } - private class TransactionsMap : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] + @Entity + @Table(name = "${NODE_DATABASE_PREFIX}transactions") + class DBTransaction( + @Id + @Column(name = "tx_id", length = 64) + var txId: String = "", - override fun valueFromRow(row: ResultRow): SignedTransaction = deserializeFromBlob(row[table.transaction]) + @Lob + @Column + var transaction: ByteArray = ByteArray(0) + ) - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.txId] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.transaction] = serializeToBlob(entry.value, finalizables) + private companion object { + fun createTransactionsMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toString() }, + fromPersistentEntity = { Pair(SecureHash.parse(it.txId), + it.transaction.deserialize( context = SerializationDefaults.STORAGE_CONTEXT)) }, + toPersistentEntity = { key: SecureHash, value: SignedTransaction -> + DBTransaction().apply { + txId = key.toString() + transaction = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + } + }, + persistentEntityClass = DBTransaction::class.java + ) } } - private val txStorage = synchronizedMap(TransactionsMap()) + private val txStorage = createTransactionsMap() - override fun addTransaction(transaction: SignedTransaction): Boolean { - val recorded = synchronized(txStorage) { - val old = txStorage[transaction.id] - if (old == null) { - txStorage.put(transaction.id, transaction) - updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) - true - } else { - false - } + override fun addTransaction(transaction: SignedTransaction): Boolean = + txStorage.addWithDuplicatesAllowed(transaction.id, transaction).apply { + updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) } - if (!recorded) { - exposedLogger.warn("Duplicate recording of transaction ${transaction.id}") - } - return recorded - } - override fun getTransaction(id: SecureHash): SignedTransaction? { - synchronized(txStorage) { - return txStorage[id] - } - } + override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage[id] private val updatesPublisher = PublishSubject.create().toSerialized() override val updates: Observable = updatesPublisher.wrapWithDatabaseTransaction() - override fun track(): DataFeed, SignedTransaction> { - synchronized(txStorage) { - return DataFeed(txStorage.values.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) - } - } + override fun track(): DataFeed, SignedTransaction> = + DataFeed(txStorage.allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) @VisibleForTesting - val transactions: Iterable get() = synchronized(txStorage) { - txStorage.values.toList() - } + val transactions: Iterable get() = txStorage.allPersisted().map { it.second }.toList() } diff --git a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt index 1483ac4e9a..731b82df36 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt @@ -10,6 +10,11 @@ import net.corda.core.schemas.QueryableState import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.node.services.api.SchemaService import net.corda.core.schemas.CommonSchemaV1 +import net.corda.node.services.keys.PersistentKeyManagementService +import net.corda.node.services.persistence.DBCheckpointStorage +import net.corda.node.services.persistence.DBTransactionMappingStorage +import net.corda.node.services.persistence.DBTransactionStorage +import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.vault.VaultSchemaV1 import net.corda.schemas.CashSchemaV1 @@ -23,14 +28,25 @@ import net.corda.schemas.CashSchemaV1 */ class NodeSchemaService(customSchemas: Set = emptySet()) : SchemaService, SingletonSerializeAsToken() { - // Currently does not support configuring schema options. + // Entities for compulsory services + object NodeServices + + object NodeServicesV1 : MappedSchema(schemaFamily = NodeServices.javaClass, version = 1, + mappedTypes = listOf(DBCheckpointStorage.DBCheckpoint::class.java, + DBTransactionStorage.DBTransaction::class.java, + DBTransactionMappingStorage.DBTransactionMapping::class.java, + PersistentKeyManagementService.PersistentKey::class.java, + PersistentUniquenessProvider.PersistentUniqueness::class.java + )) // Required schemas are those used by internal Corda services // For example, cash is used by the vault for coin selection (but will be extracted as a standalone CorDapp in future) val requiredSchemas: Map = mapOf(Pair(CashSchemaV1, SchemaService.SchemaOptions()), Pair(CommonSchemaV1, SchemaService.SchemaOptions()), - Pair(VaultSchemaV1, SchemaService.SchemaOptions())) + Pair(VaultSchemaV1, SchemaService.SchemaOptions()), + Pair(NodeServicesV1, SchemaService.SchemaOptions())) + override val schemaOptions: Map = requiredSchemas.plus(customSchemas.map { mappedSchema -> Pair(mappedSchema, SchemaService.SchemaOptions()) diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt index 4f94975f5e..820f3d5809 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt @@ -1,82 +1,110 @@ package net.corda.node.services.transactions -import net.corda.core.internal.ThreadBox import net.corda.core.contracts.StateRef import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.parsePublicKeyBase58 import net.corda.core.identity.Party +import net.corda.core.internal.ThreadBox import net.corda.core.node.services.UniquenessException import net.corda.core.node.services.UniquenessProvider import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.loggerFor import net.corda.node.utilities.* import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement +import java.io.Serializable import java.util.* import javax.annotation.concurrent.ThreadSafe +import javax.persistence.* /** A RDBMS backed Uniqueness provider */ @ThreadSafe class PersistentUniquenessProvider : UniquenessProvider, SingletonSerializeAsToken() { + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}notary_commit_log") + class PersistentUniqueness ( + + @EmbeddedId + var id: StateRef = StateRef(), + + @Column(name = "consuming_transaction_id") + var consumingTxHash: String = "", + + @Column(name = "consuming_input_index", length = 36) + var consumingIndex: Int = 0, + + @Embedded + var party: Party = Party() + ) { + + @Embeddable + data class StateRef ( + @Column(name = "transaction_id") + var txId: String = "", + + @Column(name = "output_index", length = 36) + var index: Int = 0 + ) : Serializable + + @Embeddable + data class Party ( + @Column(name = "requesting_party_name") + var name: String = "", + + @Column(name = "requesting_party_key", length = 255) + var owningKey: String = "" + ) : Serializable + } + + private class InnerState { + val committedStates = createMap() + } + + private val mutex = ThreadBox(InnerState()) + companion object { - private val TABLE_NAME = "${NODE_DATABASE_PREFIX}notary_commit_log" private val log = loggerFor() - } - /** - * For each input state store the consuming transaction information. - */ - private object Table : JDBCHashedTable(TABLE_NAME) { - val output = stateRef("transaction_id", "output_index") - val consumingTxHash = secureHash("consuming_transaction_id") - val consumingIndex = integer("consuming_input_index") - val requestingParty = party("requesting_party_name", "requesting_party_key") - } - - private val committedStates = ThreadBox(object : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): StateRef = StateRef(row[table.output.txId], row[table.output.index]) - - override fun valueFromRow(row: ResultRow): UniquenessProvider.ConsumingTx = UniquenessProvider.ConsumingTx( - row[table.consumingTxHash], - row[table.consumingIndex], - Party(X500Name(row[table.requestingParty.name]), row[table.requestingParty.owningKey]) - ) - - override fun addKeyToInsert(insert: InsertStatement, - entry: Map.Entry, - finalizables: MutableList<() -> Unit>) { - insert[table.output.txId] = entry.key.txhash - insert[table.output.index] = entry.key.index + fun createMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { PersistentUniqueness.StateRef(it.txhash.toString(), it.index) }, + fromPersistentEntity = { + Pair(StateRef(SecureHash.parse(it.id.txId), it.id.index), + UniquenessProvider.ConsumingTx(SecureHash.parse(it.consumingTxHash), it.consumingIndex, + Party(X500Name(it.party.name), parsePublicKeyBase58(it.party.owningKey)))) + }, + toPersistentEntity = { key: StateRef, value: UniquenessProvider.ConsumingTx -> + PersistentUniqueness().apply { + id = PersistentUniqueness.StateRef(key.txhash.toString(), key.index) + consumingTxHash = value.id.toString() + consumingIndex = value.inputIndex + party = PersistentUniqueness.Party(value.requestingParty.name.toString()) + } + }, + persistentEntityClass = PersistentUniqueness::class.java + ) } - - override fun addValueToInsert(insert: InsertStatement, - entry: Map.Entry, - finalizables: MutableList<() -> Unit>) { - insert[table.consumingTxHash] = entry.value.id - insert[table.consumingIndex] = entry.value.inputIndex - insert[table.requestingParty.name] = entry.value.requestingParty.name.toString() - insert[table.requestingParty.owningKey] = entry.value.requestingParty.owningKey - } - }) + } override fun commit(states: List, txId: SecureHash, callerIdentity: Party) { - val conflict = committedStates.locked { - val conflictingStates = LinkedHashMap() - for (inputState in states) { - val consumingTx = get(inputState) - if (consumingTx != null) conflictingStates[inputState] = consumingTx - } - if (conflictingStates.isNotEmpty()) { - log.debug("Failure, input states already committed: ${conflictingStates.keys}") - UniquenessProvider.Conflict(conflictingStates) - } else { - states.forEachIndexed { i, stateRef -> - put(stateRef, UniquenessProvider.ConsumingTx(txId, i, callerIdentity)) + + val conflict = mutex.locked { + val conflictingStates = LinkedHashMap() + for (inputState in states) { + val consumingTx = committedStates.get(inputState) + if (consumingTx != null) conflictingStates[inputState] = consumingTx + } + if (conflictingStates.isNotEmpty()) { + log.debug("Failure, input states already committed: ${conflictingStates.keys}") + UniquenessProvider.Conflict(conflictingStates) + } else { + states.forEachIndexed { i, stateRef -> + committedStates[stateRef] = UniquenessProvider.ConsumingTx(txId, i, callerIdentity) + } + log.debug("Successfully committed all input states: $states") + null + } } - log.debug("Successfully committed all input states: $states") - null - } - } if (conflict != null) throw UniquenessException(conflict) } diff --git a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt new file mode 100644 index 0000000000..c9fb26b8b3 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt @@ -0,0 +1,113 @@ +package net.corda.node.utilities + +import net.corda.core.utilities.loggerFor +import java.util.* + + +/** + * Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [put] twice the + * behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so + * ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY + */ +class AppendOnlyPersistentMap ( + val toPersistentEntityKey: (K) -> EK, + val fromPersistentEntity: (E) -> Pair, + val toPersistentEntity: (key: K, value: V) -> E, + val persistentEntityClass: Class, + cacheBound: Long = 1024 +) { //TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size + + private companion object { + val log = loggerFor>() + } + + private val cache = NonInvalidatingCache>( + bound = cacheBound, + concurrencyLevel = 8, + loadFunction = { key -> Optional.ofNullable(loadValue(key)) } + ) + + /** + * Returns the value associated with the key, first loading that value from the storage if necessary. + */ + operator fun get(key: K): V? { + return cache.get(key).orElse(null) + } + + /** + * Returns all key/value pairs from the underlying storage. + */ + fun allPersisted(): Sequence> { + val criteriaQuery = DatabaseTransactionManager.current().session.criteriaBuilder.createQuery(persistentEntityClass) + val root = criteriaQuery.from(persistentEntityClass) + criteriaQuery.select(root) + val query = DatabaseTransactionManager.current().session.createQuery(criteriaQuery) + val result = query.resultList + return result.map { x -> fromPersistentEntity(x) }.asSequence() + } + + private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K,V) -> V?): Boolean { + var insertionAttempt = false + var isUnique = true + val existingInCache = cache.get(key) { // Thread safe, if multiple threads may wait until the first one has loaded. + insertionAttempt = true + // Key wasn't in the cache and might be in the underlying storage. + // Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key. + val existingInDb = store(key, value) + if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key. + Optional.of(existingInDb) + } else { + Optional.of(value) + } + } + if (!insertionAttempt) { + if (existingInCache.isPresent) { + // Key already exists in cache, do nothing. + isUnique = false + } else { + // This happens when the key was queried before with no value associated. We invalidate the cached null + // value and recursively call set again. This is to avoid race conditions where another thread queries after + // the invalidate but before the set. + cache.invalidate(key) + return set(key, value, logWarning, store) + } + } + if (logWarning && !isUnique) { + log.warn("Double insert in ${this.javaClass.name} for entity class $persistentEntityClass key $key, not inserting the second time") + } + return isUnique + } + + /** + * Puts the value into the map and the underlying storage. + * Inserting the duplicated key may be unpredictable. + */ + operator fun set(key: K, value: V) = + set(key, value, logWarning = false) { + key,value -> DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) + null + } + + /** + * Puts the value into the map and underlying storage. + * Duplicated key is not added into the map and underlying storage. + * @return true if added key was unique, otherwise false + */ + fun addWithDuplicatesAllowed(key: K, value: V): Boolean = + set(key, value) { + key, value -> + val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + if (existingEntry == null) { + DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) + null + } else { + fromPersistentEntity(existingEntry).second + } + } + + private fun loadValue(key: K): V? { + val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + return result?.let(fromPersistentEntity)?.second + } + +} diff --git a/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt b/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt index 1d7e2d79a2..a29fee242e 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt @@ -2,6 +2,11 @@ package net.corda.node.utilities import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariDataSource +import net.corda.core.node.services.IdentityService +import net.corda.core.schemas.MappedSchema +import net.corda.node.services.database.HibernateConfiguration +import net.corda.node.services.schema.NodeSchemaService +import org.hibernate.SessionFactory import org.jetbrains.exposed.sql.Database import rx.Observable @@ -15,15 +20,21 @@ import java.util.concurrent.CopyOnWriteArrayList //HikariDataSource implements Closeable which allows CordaPersistence to be Closeable -class CordaPersistence(var dataSource: HikariDataSource, databaseProperties: Properties): Closeable { +class CordaPersistence(var dataSource: HikariDataSource, var nodeSchemaService: NodeSchemaService, val identitySvc: ()-> IdentityService, databaseProperties: Properties): Closeable { /** Holds Exposed database, the field will be removed once Exposed library is removed */ lateinit var database: Database var transactionIsolationLevel = parserTransactionIsolationLevel(databaseProperties.getProperty("transactionIsolationLevel")) + val entityManagerFactory: SessionFactory by lazy(LazyThreadSafetyMode.NONE) { + transaction { + HibernateConfiguration(nodeSchemaService, databaseProperties, identitySvc).sessionFactoryForRegisteredSchemas() + } + } + companion object { - fun connect(dataSource: HikariDataSource, databaseProperties: Properties): CordaPersistence { - return CordaPersistence(dataSource, databaseProperties).apply { + fun connect(dataSource: HikariDataSource, nodeSchemaService: NodeSchemaService, identitySvc: () -> IdentityService, databaseProperties: Properties): CordaPersistence { + return CordaPersistence(dataSource, nodeSchemaService, identitySvc, databaseProperties).apply { DatabaseTransactionManager(this) } } @@ -89,10 +100,10 @@ class CordaPersistence(var dataSource: HikariDataSource, databaseProperties: Pro } } -fun configureDatabase(dataSourceProperties: Properties, databaseProperties: Properties?): CordaPersistence { +fun configureDatabase(dataSourceProperties: Properties, databaseProperties: Properties?, entitySchemas: Set = emptySet(), identitySvc: ()-> IdentityService): CordaPersistence { val config = HikariConfig(dataSourceProperties) val dataSource = HikariDataSource(config) - val persistence = CordaPersistence.connect(dataSource, databaseProperties ?: Properties()) + val persistence = CordaPersistence.connect(dataSource, NodeSchemaService(entitySchemas), identitySvc, databaseProperties ?: Properties()) //org.jetbrains.exposed.sql.Database will be removed once Exposed library is removed val database = Database.connect(dataSource) { _ -> ExposedTransactionManager() } @@ -156,7 +167,7 @@ private class DatabaseTransactionWrappingSubscriber(val db: CordaPersistence? } // A subscriber that wraps another but does not pass on observations to it. -private class NoOpSubscriber(t: Subscriber) : Subscriber(t) { +private class NoOpSubscriber(t: Subscriber): Subscriber(t) { override fun onCompleted() { } @@ -191,15 +202,14 @@ fun rx.Observable.wrapWithDatabaseTransaction(db: CordaPersistence? } } - -fun parserTransactionIsolationLevel(property: String?) : Int = - when (property) { - "none" -> Connection.TRANSACTION_NONE - "readUncommitted" -> Connection.TRANSACTION_READ_UNCOMMITTED - "readCommitted" -> Connection.TRANSACTION_READ_COMMITTED - "repeatableRead" -> Connection.TRANSACTION_REPEATABLE_READ - "serializable" -> Connection.TRANSACTION_SERIALIZABLE - else -> { - Connection.TRANSACTION_REPEATABLE_READ +fun parserTransactionIsolationLevel(property: String?): Int = + when (property) { + "none" -> Connection.TRANSACTION_NONE + "readUncommitted" -> Connection.TRANSACTION_READ_UNCOMMITTED + "readCommitted" -> Connection.TRANSACTION_READ_COMMITTED + "repeatableRead" -> Connection.TRANSACTION_REPEATABLE_READ + "serializable" -> Connection.TRANSACTION_SERIALIZABLE + else -> { + Connection.TRANSACTION_REPEATABLE_READ + } } - } diff --git a/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt b/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt index c530f956a3..da64097850 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt @@ -1,6 +1,8 @@ package net.corda.node.utilities import co.paralleluniverse.strands.Strand +import org.hibernate.Session +import org.hibernate.Transaction import rx.subjects.PublishSubject import rx.subjects.Subject import java.sql.Connection @@ -21,13 +23,28 @@ class DatabaseTransaction(isolation: Int, val threadLocal: ThreadLocal private constructor( + val cache: LoadingCache +): LoadingCache by cache { + + constructor(bound: Long, concurrencyLevel: Int, loadFunction: (K) -> V) : + this(buildCache(bound, concurrencyLevel, loadFunction)) + + private companion object { + private fun buildCache(bound: Long, concurrencyLevel: Int, loadFunction: (K) -> V): LoadingCache { + val builder = CacheBuilder.newBuilder().maximumSize(bound).concurrencyLevel(concurrencyLevel) + return builder.build(NonInvalidatingCacheLoader(loadFunction)) + } + } + + // TODO look into overriding loadAll() if we ever use it + private class NonInvalidatingCacheLoader(val loadFunction: (K) -> V) : CacheLoader() { + override fun reload(key: K, oldValue: V): ListenableFuture { + throw IllegalStateException("Non invalidating cache refreshed") + } + override fun load(key: K) = loadFunction(key) + override fun loadAll(keys: Iterable): MutableMap { + return super.loadAll(keys) + } + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt index e4ac1bf76f..923457f44b 100644 --- a/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt @@ -10,7 +10,6 @@ import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.PersistentStateRef import net.corda.core.serialization.deserialize import net.corda.core.transactions.SignedTransaction -import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.vault.VaultSchemaV1 @@ -27,6 +26,7 @@ import net.corda.testing.contracts.fillWithSomeTestLinearStates import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import net.corda.testing.schemas.DummyLinearStateSchemaV1 import net.corda.testing.schemas.DummyLinearStateSchemaV2 import org.assertj.core.api.Assertions @@ -65,11 +65,10 @@ class HibernateConfigurationTest : TestDependencyInjectionBase() { issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY, BOB_KEY, BOC_KEY) val dataSourceProps = makeTestDataSourceProperties() val defaultDatabaseProperties = makeTestDatabaseProperties() - database = configureDatabase(dataSourceProps, defaultDatabaseProperties) + database = configureDatabase(dataSourceProps, defaultDatabaseProperties, identitySvc = ::makeTestIdentityService) val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3) database.transaction { - val identityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) - hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), identityService) + hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), ::makeTestIdentityService) services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) { override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) diff --git a/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt index 519008ad97..75b5ba5dab 100644 --- a/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt @@ -17,13 +17,11 @@ import net.corda.node.services.vault.schemas.requery.VaultSchema import net.corda.node.services.vault.schemas.requery.VaultStatesEntity import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.testing.ALICE_PUBKEY -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_PUBKEY_1 -import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.* import net.corda.testing.contracts.DummyContract import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions import org.junit.After import org.junit.Assert.assertEquals @@ -42,7 +40,7 @@ class RequeryConfigurationTest : TestDependencyInjectionBase() { @Before fun setUp() { val dataSourceProperties = makeTestDataSourceProperties() - database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties()) + database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) newTransactionStorage() newRequeryStorage(dataSourceProperties) } diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 6b2fb1d039..2480ad4cab 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -10,9 +10,6 @@ import net.corda.core.node.ServiceHub import net.corda.core.node.services.VaultService import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.TransactionBuilder -import net.corda.testing.ALICE_KEY -import net.corda.testing.DUMMY_CA -import net.corda.testing.DUMMY_NOTARY import net.corda.node.services.MockServiceHubInternal import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.persistence.DBCheckpointStorage @@ -22,14 +19,11 @@ import net.corda.node.services.vault.NodeVaultService import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase +import net.corda.testing.* import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockKeyManagementService -import net.corda.testing.getTestX509Name -import net.corda.testing.testNodeConfiguration -import net.corda.testing.initialiseTestSerialization import net.corda.testing.node.* import net.corda.testing.node.TestClock -import net.corda.testing.resetTestSerialization import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x500.X500Name import org.junit.After @@ -77,7 +71,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { smmHasRemovedAllFlows = CountDownLatch(1) calls = 0 val dataSourceProps = makeTestDataSourceProperties() - database = configureDatabase(dataSourceProps, makeTestDatabaseProperties()) + database = configureDatabase(dataSourceProps, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) val identityService = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) val kms = MockKeyManagementService(identityService, ALICE_KEY) diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt index c3dd27595d..0d19c27a65 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt @@ -13,6 +13,7 @@ import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.api.MonitoringService import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.network.InMemoryNetworkMapCache import net.corda.node.services.network.NetworkMapService import net.corda.node.services.transactions.PersistentUniquenessProvider @@ -23,6 +24,7 @@ import net.corda.testing.* import net.corda.testing.node.MOCK_VERSION_INFO import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.After @@ -69,7 +71,7 @@ class ArtemisMessagingTests : TestDependencyInjectionBase() { baseDirectory = baseDirectory, myLegalName = ALICE.name) LogHelper.setLevel(PersistentUniquenessProvider::class) - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) networkMapRegistrationFuture = doneFuture(Unit) } diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index 11e99901d8..e060680d41 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -11,8 +11,8 @@ import net.corda.testing.LogHelper import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.After import org.junit.Before import org.junit.Test @@ -33,7 +33,7 @@ class DBCheckpointStorageTests : TestDependencyInjectionBase() { @Before fun setUp() { LogHelper.setLevel(PersistentUniquenessProvider::class) - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) newCheckpointStorage() } @@ -94,16 +94,6 @@ class DBCheckpointStorageTests : TestDependencyInjectionBase() { } } - @Test - fun `remove unknown checkpoint`() { - val checkpoint = newCheckpoint() - database.transaction { - assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { - checkpointStorage.removeCheckpoint(checkpoint) - } - } - } - @Test fun `add two checkpoints then remove first one`() { val firstCheckpoint = newCheckpoint() diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt index 65aa041280..ba59ceefa1 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt @@ -4,19 +4,28 @@ import net.corda.core.contracts.StateRef import net.corda.core.crypto.Crypto import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SignatureMetadata +import net.corda.core.node.services.VaultService import net.corda.core.crypto.TransactionSignature +import net.corda.core.schemas.MappedSchema import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction +import net.corda.node.services.database.HibernateConfiguration +import net.corda.node.services.schema.HibernateObserver +import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.node.services.vault.NodeVaultService +import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.testing.ALICE_PUBKEY -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.LogHelper -import net.corda.testing.TestDependencyInjectionBase +import net.corda.schemas.CashSchemaV1 +import net.corda.schemas.SampleCashSchemaV2 +import net.corda.schemas.SampleCashSchemaV3 +import net.corda.testing.* +import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before @@ -27,11 +36,43 @@ import kotlin.test.assertEquals class DBTransactionStorageTests : TestDependencyInjectionBase() { lateinit var database: CordaPersistence lateinit var transactionStorage: DBTransactionStorage + lateinit var services: MockServices + val vault: VaultService get() = services.vaultService + // Hibernate configuration objects + lateinit var hibernateConfig: HibernateConfiguration @Before fun setUp() { LogHelper.setLevel(PersistentUniquenessProvider::class) - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + val dataSourceProps = makeTestDataSourceProperties() + + val transactionSchema = MappedSchema(schemaFamily = javaClass, version = 1, + mappedTypes = listOf(DBTransactionStorage.DBTransaction::class.java)) + + val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3, transactionSchema) + + database = configureDatabase(dataSourceProps, makeTestDatabaseProperties(), customSchemas, identitySvc = ::makeTestIdentityService) + + database.transaction { + + hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) + + services = object : MockServices(BOB_KEY) { + override val vaultService: VaultService get() { + val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties()) + hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) + return vaultService + } + + override fun recordTransactions(txs: Iterable) { + for (stx in txs) { + validatedTransactions.addTransaction(stx) + } + // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. + vaultService.notifyAll(txs.map { it.tx }) + } + } + } newTransactionStorage() } @@ -120,6 +161,37 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() { } } + @Test + fun `transaction saved twice in same DB transaction scope`() { + val firstTransaction = newTransaction() + database.transaction { + transactionStorage.addTransaction(firstTransaction) + transactionStorage.addTransaction(firstTransaction) + } + assertTransactionIsRetrievable(firstTransaction) + database.transaction { + assertThat(transactionStorage.transactions).containsOnly(firstTransaction) + } + } + + @Test + fun `transaction saved twice in two DB transaction scopes`() { + val firstTransaction = newTransaction() + val secondTransaction = newTransaction() + database.transaction { + transactionStorage.addTransaction(firstTransaction) + } + + database.transaction { + transactionStorage.addTransaction(secondTransaction) + transactionStorage.addTransaction(firstTransaction) + } + assertTransactionIsRetrievable(firstTransaction) + database.transaction { + assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction) + } + } + @Test fun `updates are fired`() { val future = transactionStorage.updates.toFuture() diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt index 076db8da53..287c4638cb 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt @@ -17,6 +17,7 @@ import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.junit.After import org.junit.Before import org.junit.Test @@ -43,7 +44,7 @@ class NodeAttachmentStorageTest { LogHelper.setLevel(PersistentUniquenessProvider::class) dataSourceProperties = makeTestDataSourceProperties() - database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties()) + database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = makeTestDatabaseProperties()) fs = Jimfs.newFileSystem(Configuration.unix()) diff --git a/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt b/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt index ba470599c6..f81aac2b48 100644 --- a/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt @@ -18,6 +18,7 @@ import net.corda.testing.MEGA_CORP import net.corda.testing.MOCK_IDENTITIES import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.hibernate.annotations.Cascade import org.hibernate.annotations.CascadeType import org.jetbrains.exposed.sql.transactions.TransactionManager @@ -35,7 +36,7 @@ class HibernateObserverTests { @Before fun setUp() { LogHelper.setLevel(HibernateObserver::class) - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) } @After @@ -105,8 +106,7 @@ class HibernateObserverTests { } @Suppress("UNUSED_VARIABLE") - val identityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) - val observer = HibernateObserver(rawUpdatesPublisher, HibernateConfiguration(schemaService, makeTestDatabaseProperties(), identityService)) + val observer = HibernateObserver(rawUpdatesPublisher, HibernateConfiguration(schemaService, makeTestDatabaseProperties(), ::makeTestIdentityService)) database.transaction { rawUpdatesPublisher.onNext(Vault.Update(emptySet(), setOf(StateAndRef(TransactionState(TestState(), MEGA_CORP), StateRef(SecureHash.sha256("dummy"), 0))))) val parentRowCountResult = TransactionManager.current().connection.prepareStatement("select count(*) from Parents").executeQuery() diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt index 17aa6b22bc..95082a625b 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt @@ -11,11 +11,10 @@ import net.corda.core.utilities.getOrThrow import net.corda.node.services.network.NetworkMapService import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.testing.LogHelper -import net.corda.testing.TestDependencyInjectionBase -import net.corda.testing.freeLocalHostAndPort +import net.corda.testing.* import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.jetbrains.exposed.sql.Transaction import org.junit.After import org.junit.Before @@ -35,7 +34,7 @@ class DistributedImmutableMapTests : TestDependencyInjectionBase() { fun setup() { LogHelper.setLevel("-org.apache.activemq") LogHelper.setLevel(NetworkMapService::class) - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) cluster = setUpCluster() } diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt index f8db1ecee0..57eefb0f22 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt @@ -4,12 +4,10 @@ import net.corda.core.crypto.SecureHash import net.corda.core.node.services.UniquenessException import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.testing.LogHelper -import net.corda.testing.MEGA_CORP -import net.corda.testing.TestDependencyInjectionBase -import net.corda.testing.generateStateRef +import net.corda.testing.* import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.junit.After import org.junit.Before import org.junit.Test @@ -25,7 +23,7 @@ class PersistentUniquenessProviderTests : TestDependencyInjectionBase() { @Before fun setUp() { LogHelper.setLevel(PersistentUniquenessProvider::class) - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) } @After diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 3dacf3f2c1..1c8c3dfab2 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -15,13 +15,9 @@ import net.corda.core.node.services.* import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.* import net.corda.core.utilities.seconds -import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.toHexString -import net.corda.node.services.database.HibernateConfiguration -import net.corda.node.services.identity.InMemoryIdentityService -import net.corda.node.services.schema.NodeSchemaService import net.corda.core.utilities.* import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase @@ -34,6 +30,7 @@ import net.corda.testing.contracts.* import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDatabaseAndMockServices import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import net.corda.testing.schemas.DummyLinearStateSchemaV1 import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.assertThat @@ -77,7 +74,7 @@ class VaultQueryTests : TestDependencyInjectionBase() { @Ignore @Test fun createPersistentTestDb() { - val database = configureDatabase(makePersistentDataSourceProperties(), makeTestDatabaseProperties()) + val database = configureDatabase(makePersistentDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) setUpDb(database, 5000) diff --git a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt index f330411d3f..0d99e3f3b8 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt @@ -5,6 +5,7 @@ import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.tee import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Test @@ -20,7 +21,7 @@ class ObservablesTests { val toBeClosed = mutableListOf() fun createDatabase(): CordaPersistence { - val database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + val database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) toBeClosed += database return database } diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt index 5898271c45..c501759e09 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt @@ -18,10 +18,7 @@ import net.corda.irs.flows.RatesFixFlow import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase import net.corda.testing.* -import net.corda.testing.node.MockNetwork -import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties -import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.* import org.bouncycastle.asn1.x500.X500Name import org.junit.After import org.junit.Assert @@ -60,7 +57,7 @@ class NodeInterestRatesTest : TestDependencyInjectionBase() { @Before fun setUp() { - database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) database.transaction { oracle = NodeInterestRates.Oracle( MEGA_CORP, diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt index 001007f968..d0635b4205 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -88,7 +88,7 @@ open class MockServices(vararg val keys: KeyPair) : ServiceHub { lateinit var hibernatePersister: HibernateObserver - fun makeVaultService(dataSourceProps: Properties, hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), identityService)): VaultService { + fun makeVaultService(dataSourceProps: Properties, hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultService { val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties()) hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) return vaultService @@ -216,13 +216,15 @@ fun makeTestDatabaseProperties(): Properties { return props } +fun makeTestIdentityService() = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) + fun makeTestDatabaseAndMockServices(customSchemas: Set = setOf(CommercialPaperSchemaV1, DummyLinearStateSchemaV1, CashSchemaV1), keys: List = listOf(MEGA_CORP_KEY)): Pair { val dataSourceProps = makeTestDataSourceProperties() val databaseProperties = makeTestDatabaseProperties() - val database = configureDatabase(dataSourceProps, databaseProperties) + + val database = configureDatabase(dataSourceProps, databaseProperties, identitySvc = ::makeTestIdentityService) val mockService = database.transaction { - val identityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) - val hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), databaseProperties, identityService) + val hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), databaseProperties, identitySvc = ::makeTestIdentityService) object : MockServices(*(keys.toTypedArray())) { override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt index 641396028c..0f55f1bdb3 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt @@ -32,11 +32,11 @@ class SimpleNode(val config: NodeConfiguration, val address: NetworkHostAndPort rpcAddress: NetworkHostAndPort = freeLocalHostAndPort(), trustRoot: X509Certificate) : AutoCloseable { - val database: CordaPersistence = configureDatabase(config.dataSourceProperties, config.database) val userService = RPCUserServiceImpl(config.rpcUsers) val monitoringService = MonitoringService(MetricRegistry()) val identity: KeyPair = generateKeyPair() val identityService: IdentityService = InMemoryIdentityService(trustRoot = trustRoot) + val database: CordaPersistence = configureDatabase(config.dataSourceProperties, config.database, identitySvc = {InMemoryIdentityService(trustRoot = trustRoot)}) val keyService: KeyManagementService = E2ETestKeyManagementService(identityService, setOf(identity)) val executor = ServiceAffinityExecutor(config.myLegalName.commonName, 1) // TODO: We should have a dummy service hub rather than change behaviour in tests From 3407cd4580885e8344f76e208c59f730487ec895 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Mon, 14 Aug 2017 13:41:52 +0100 Subject: [PATCH 2/3] Fix to test in InMemoryIdentityServiceTests --- .../network/InMemoryIdentityServiceTests.kt | 25 ++++++------------- .../corda/testing/SerializationTestHelpers.kt | 2 +- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt index b425fdcc8e..4ef14ae0e4 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt @@ -115,12 +115,12 @@ class InMemoryIdentityServiceTests { service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party) var actual = service.anonymousFromKey(aliceTxIdentity.party.owningKey) - assertEquals(aliceTxIdentity, actual!!) + assertEquals(aliceTxIdentity, actual!!) assertNull(service.anonymousFromKey(bobTxIdentity.party.owningKey)) service.verifyAndRegisterAnonymousIdentity(bobTxIdentity, bob.party) actual = service.anonymousFromKey(bobTxIdentity.party.owningKey) - assertEquals(bobTxIdentity, actual!!) + assertEquals(bobTxIdentity, actual!!) } /** @@ -131,37 +131,28 @@ class InMemoryIdentityServiceTests { fun `assert ownership`() { withTestSerialization { val trustRoot = DUMMY_CA - val (alice, aliceTxIdentity) = createParty(ALICE.name, trustRoot) - - val certFactory = CertificateFactory.getInstance("X509") - val bobRootKey = Crypto.generateKeyPair() - val bobRoot = getTestPartyAndCertificate(BOB.name, bobRootKey.public) - val bobRootCert = bobRoot.certificate - val bobTxKey = Crypto.generateKeyPair() - val bobTxCert = X509Utilities.createCertificate(CertificateType.IDENTITY, bobRootCert, bobRootKey, BOB.name, bobTxKey.public) - val bobCertPath = certFactory.generateCertPath(listOf(bobTxCert.cert, bobRootCert.cert)) - val bob = PartyAndCertificate(BOB.name, bobRootKey.public, bobRootCert, bobCertPath) + val (alice, anonymousAlice) = createParty(ALICE.name, trustRoot) + val (bob, anonymousBob) = createParty(BOB.name, trustRoot) // Now we have identities, construct the service and let it know about both val service = InMemoryIdentityService(setOf(alice, bob), emptyMap(), trustRoot.certificate.cert) - service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party) - val anonymousBob = AnonymousPartyAndPath(AnonymousParty(bobTxKey.public),bobCertPath) + service.verifyAndRegisterAnonymousIdentity(anonymousAlice, alice.party) service.verifyAndRegisterAnonymousIdentity(anonymousBob, bob.party) // Verify that paths are verified - service.assertOwnership(alice.party, aliceTxIdentity.party) + service.assertOwnership(alice.party, anonymousAlice.party) service.assertOwnership(bob.party, anonymousBob.party) assertFailsWith { service.assertOwnership(alice.party, anonymousBob.party) } assertFailsWith { - service.assertOwnership(bob.party, aliceTxIdentity.party) + service.assertOwnership(bob.party, anonymousAlice.party) } assertFailsWith { val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded) - service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), aliceTxIdentity.party) + service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), anonymousAlice.party) } } } diff --git a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt index 82d7925fec..c250214420 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt @@ -6,7 +6,7 @@ import net.corda.core.utilities.ByteSequence import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.nodeapi.internal.serialization.* -fun withTestSerialization(block: () -> T): T { +inline fun withTestSerialization(block: () -> T): T { initialiseTestSerialization() try { return block() From 2f08425c43776dc605b0c8d8319648ec0d6b3c5c Mon Sep 17 00:00:00 2001 From: Rick Parker Date: Mon, 14 Aug 2017 17:24:04 +0100 Subject: [PATCH 3/3] Refactor KryoAMQPSerializer to go through generic APIs to access AMQP serialization (#1225) --- .../net/corda/client/rpc/CordaRPCClient.kt | 2 +- .../rpc/serialization/SerializationScheme.kt | 5 +- .../core/serialization/SerializationAPI.kt | 5 ++ .../kotlin/net/corda/nodeapi/RPCStructures.kt | 4 +- .../serialization/AMQPSerializationScheme.kt | 17 ++++- .../serialization/CordaClassResolver.kt | 40 ++++++----- .../serialization/DefaultKryoCustomizer.kt | 5 +- .../serialization/KryoAMQPSerializer.kt | 29 ++------ .../serialization/SerializationScheme.kt | 9 +-- .../serialization/CordaClassResolverTests.kt | 69 +++++++++++-------- .../internal/serialization/KryoTests.kt | 7 +- .../serialization/SerializationTokenTest.kt | 7 +- .../amqp/SerializationOutputTests.kt | 6 +- .../kotlin/net/corda/node/internal/Node.kt | 2 +- .../node/serialization/SerializationScheme.kt | 5 +- .../corda/node/utilities/X509UtilitiesTest.kt | 4 +- .../corda/testing/SerializationTestHelpers.kt | 8 ++- .../kotlin/net/corda/verifier/Verifier.kt | 5 +- 18 files changed, 129 insertions(+), 100 deletions(-) diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index a0a4a7ea1f..3ec8268c3e 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -71,7 +71,7 @@ class CordaRPCClient( fun initialiseSerialization() { try { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { - registerScheme(KryoClientSerializationScheme()) + registerScheme(KryoClientSerializationScheme(this)) registerScheme(AMQPClientSerializationScheme()) } SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt index 667cf76539..0bb26b93fb 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt @@ -3,20 +3,21 @@ package net.corda.client.rpc.serialization import com.esotericsoftware.kryo.pool.KryoPool import net.corda.client.rpc.internal.RpcClientObservableSerializer import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.utilities.ByteSequence import net.corda.nodeapi.RPCKryo import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 -class KryoClientSerializationScheme : AbstractKryoSerializationScheme() { +class KryoClientSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { return byteSequence == KryoHeaderV0_1 && (target == SerializationContext.UseCase.RPCClient || target == SerializationContext.UseCase.P2P) } override fun rpcClientKryoPool(context: SerializationContext): KryoPool { return KryoPool.Builder { - DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, context.whitelist)).apply { classLoader = context.deserializationClassLoader } + DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader } }.build() } diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index 59c0d6a9d6..7dab1e0243 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -81,6 +81,11 @@ interface SerializationContext { */ fun withWhitelisted(clazz: Class<*>): SerializationContext + /** + * Helper method to return a new context based on this context but with serialization using the format this header sequence represents. + */ + fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext + /** * The use case that we are serializing for, since it influences the implementations chosen. */ diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt index 7e8b11c3a4..a4be40c829 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt @@ -8,6 +8,8 @@ import net.corda.core.concurrent.CordaFuture import net.corda.core.CordaRuntimeException import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.toFuture import net.corda.core.toObservable import net.corda.nodeapi.config.OldConfig @@ -47,7 +49,7 @@ class PermissionException(msg: String) : RuntimeException(msg) // The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly. // This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes, // because we can see everything we're using in one place. -class RPCKryo(observableSerializer: Serializer>, whitelist: ClassWhitelist) : CordaKryo(CordaClassResolver(whitelist)) { +class RPCKryo(observableSerializer: Serializer>, val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : CordaKryo(CordaClassResolver(serializationFactory, serializationContext)) { init { DefaultKryoCustomizer.customize(this) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt index 9772bef28a..3d1ca75f34 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt @@ -11,9 +11,22 @@ import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory import java.util.concurrent.ConcurrentHashMap -private const val AMQP_ENABLED = false +internal val AMQP_ENABLED get() = SerializationDefaults.P2P_CONTEXT.preferedSerializationVersion == AmqpHeaderV1_0 abstract class AbstractAMQPSerializationScheme : SerializationScheme { + internal companion object { + fun registerCustomSerializers(factory: SerializerFactory) { + factory.apply { + register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this)) + register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this)) + } + } + } + private val serializerFactoriesForContexts = ConcurrentHashMap, SerializerFactory>() protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory @@ -30,7 +43,7 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme { rpcServerSerializerFactory(context) else -> SerializerFactory(context.whitelist) // TODO pass class loader also } - } + }.also { registerCustomSerializers(it) } } override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt index 02c0e32be0..9c180148fd 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt @@ -6,10 +6,9 @@ import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.serializers.FieldSerializer import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.Util -import net.corda.core.serialization.AttachmentsClassLoader -import net.corda.core.serialization.ClassWhitelist -import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.* import net.corda.core.utilities.loggerFor +import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0 import java.io.PrintWriter import java.lang.reflect.Modifier.isAbstract import java.nio.charset.StandardCharsets @@ -22,23 +21,13 @@ fun Kryo.addToWhitelist(type: Class<*>) { ((classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type) } -fun makeStandardClassResolver(): ClassResolver { - return CordaClassResolver(GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist())) -} - -fun makeNoWhitelistClassResolver(): ClassResolver { - return CordaClassResolver(AllWhitelist) -} - -fun makeAllButBlacklistedClassResolver(): ClassResolver { - return CordaClassResolver(AllButBlacklisted) -} - /** * @param amqpEnabled Setting this to true turns on experimental AMQP serialization for any class annotated with * [CordaSerializable]. */ -class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean = false) : DefaultClassResolver() { +class CordaClassResolver(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : DefaultClassResolver() { + val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist) + /** Returns the registration for the specified class, or null if the class is not registered. */ override fun getRegistration(type: Class<*>): Registration? { return super.getRegistration(type) ?: checkClass(type) @@ -78,9 +67,9 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean // If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the // case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures // are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph. - if (checkForAnnotation(type) && amqpEnabled) { + if (checkForAnnotation(type) && AMQP_ENABLED) { // Build AMQP serializer - return register(Registration(type, KryoAMQPSerializer, NAME.toInt())) + return register(Registration(type, KryoAMQPSerializer(serializationFactory, serializationContext), NAME.toInt())) } val objectInstance = try { @@ -179,6 +168,21 @@ class GlobalTransientClassWhiteList(val delegate: ClassWhitelist) : MutableClass } } +/** + * A whitelist that can be customised via the [CordaPluginRegistry], since implements [MutableClassWhitelist]. + */ +class TransientClassWhiteList(val delegate: ClassWhitelist) : MutableClassWhitelist, ClassWhitelist by delegate { + val whitelist: MutableSet = Collections.synchronizedSet(mutableSetOf()) + + override fun hasListed(type: Class<*>): Boolean { + return (type.name in whitelist) || delegate.hasListed(type) + } + + override fun add(entry: Class<*>) { + whitelist += entry.name + } +} + /** * This class is not currently used, but can be installed to log a large number of missing entries from the whitelist diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt index f5ece98d46..5112ea6a33 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt @@ -46,10 +46,7 @@ import kotlin.collections.ArrayList object DefaultKryoCustomizer { private val pluginRegistries: List by lazy { - // No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. - val unusedKryo = Kryo(makeStandardClassResolver(), MapReferenceResolver()) - val customization = KryoSerializationCustomization(unusedKryo) - ServiceLoader.load(CordaPluginRegistry::class.java, this.javaClass.classLoader).toList().filter { it.customizeSerialization(customization) } + ServiceLoader.load(CordaPluginRegistry::class.java, this.javaClass.classLoader).toList() } fun customize(kryo: Kryo): Kryo { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt index d864fa3941..16dec8a83e 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt @@ -4,7 +4,11 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.sequence +import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0 import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory @@ -15,38 +19,19 @@ import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory * * There is no need to write out the length, since this can be peeked out of the first few bytes of the stream. */ -object KryoAMQPSerializer : Serializer() { - internal fun registerCustomSerializers(factory: SerializerFactory) { - factory.apply { - register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer) - register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this)) - register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer) - register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer) - register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer) - register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this)) - } - } - - // TODO: need to sort out the whitelist... we currently do not apply the whitelist attached to the [Kryo] - // instance to the factory. We need to do this before turning on AMQP serialization. - private val serializerFactory = SerializerFactory().apply { - registerCustomSerializers(this) - } - +class KryoAMQPSerializer(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : Serializer() { override fun write(kryo: Kryo, output: Output, obj: Any) { - val amqpOutput = SerializationOutput(serializerFactory) - val bytes = amqpOutput.serialize(obj).bytes + val bytes = serializationFactory.serialize(obj, serializationContext.withPreferredSerializationVersion(AmqpHeaderV1_0)).bytes // No need to write out the size since it's encoded within the AMQP. output.write(bytes) } override fun read(kryo: Kryo, input: Input, type: Class): Any { - val amqpInput = DeserializationInput(serializerFactory) // Use our helper functions to peek the size of the serialized object out of the AMQP byte stream. val peekedBytes = input.readBytes(DeserializationInput.BYTES_NEEDED_TO_PEEK) val size = DeserializationInput.peekSize(peekedBytes) val allBytes = peekedBytes.copyOf(size) input.readBytes(allBytes, peekedBytes.size, size - peekedBytes.size) - return amqpInput.deserialize(SerializedBytes(allBytes), type) + return serializationFactory.deserialize(allBytes.sequence(), type, serializationContext) } } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt index 688dfacf7e..4ab8c3c27e 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt @@ -34,7 +34,6 @@ data class SerializationContextImpl(override val preferedSerializationVersion: B override val properties: Map, override val objectReferencesEnabled: Boolean, override val useCase: SerializationContext.UseCase) : SerializationContext { - override fun withProperty(property: Any, value: Any): SerializationContext { return copy(properties = properties + (property to value)) } @@ -52,6 +51,8 @@ data class SerializationContextImpl(override val preferedSerializationVersion: B override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name }) } + + override fun withPreferredSerializationVersion(versionHeader: ByteSequence) = copy(preferedSerializationVersion = versionHeader) } private const val HEADER_SIZE: Int = 8 @@ -118,7 +119,7 @@ private object AutoCloseableSerialisationDetector : Serializer() override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") } -abstract class AbstractKryoSerializationScheme : SerializationScheme { +abstract class AbstractKryoSerializationScheme(val serializationFactory: SerializationFactory) : SerializationScheme { private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool @@ -130,7 +131,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { SerializationContext.UseCase.Checkpoint -> KryoPool.Builder { val serializer = Fiber.getFiberSerializer(false) as KryoSerializer - val classResolver = makeNoWhitelistClassResolver().apply { setKryo(serializer.kryo) } + val classResolver = CordaClassResolver(serializationFactory, context).apply { setKryo(serializer.kryo) } // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } serializer.kryo.apply { @@ -146,7 +147,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { rpcServerKryoPool(context) else -> KryoPool.Builder { - DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context.whitelist))).apply { classLoader = it.second } + DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(serializationFactory, context))).apply { classLoader = it.second } }.build() } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt index 7410128111..88ac9f4049 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt @@ -5,8 +5,8 @@ import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.util.MapReferenceResolver import net.corda.core.node.services.AttachmentStorage -import net.corda.core.serialization.AttachmentsClassLoader -import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence import net.corda.nodeapi.AttachmentClassLoaderTests import net.corda.testing.node.MockAttachmentStorage import org.junit.Rule @@ -76,71 +76,84 @@ class DefaultSerializableSerializer : Serializer() { } class CordaClassResolverTests { + val factory: SerializationFactory = object : SerializationFactory { + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + } + + val emptyWhitelistContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P) + val allButBlacklistedContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P) + @Test fun `Annotation on enum works for specialised entries`() { // TODO: Remove this suppress when we upgrade to kotlin 1.1 or when JetBrain fixes the bug. @Suppress("UNSUPPORTED_FEATURE") - CordaClassResolver(EmptyWhitelist).getRegistration(Foo.Bar::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Foo.Bar::class.java) } @Test fun `Annotation on array element works`() { val values = arrayOf(Element()) - CordaClassResolver(EmptyWhitelist).getRegistration(values.javaClass) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(values.javaClass) } @Test fun `Annotation not needed on abstract class`() { - CordaClassResolver(EmptyWhitelist).getRegistration(AbstractClass::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(AbstractClass::class.java) } @Test fun `Annotation not needed on interface`() { - CordaClassResolver(EmptyWhitelist).getRegistration(Interface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Interface::class.java) } @Test fun `Calling register method on modified Kryo does not consult the whitelist`() { - val kryo = CordaKryo(CordaClassResolver(EmptyWhitelist)) + val kryo = CordaKryo(CordaClassResolver(factory, emptyWhitelistContext)) kryo.register(NotSerializable::class.java) } @Test(expected = KryoException::class) fun `Calling register method on unmodified Kryo does consult the whitelist`() { - val kryo = Kryo(CordaClassResolver(EmptyWhitelist), MapReferenceResolver()) + val kryo = Kryo(CordaClassResolver(factory, emptyWhitelistContext), MapReferenceResolver()) kryo.register(NotSerializable::class.java) } @Test(expected = KryoException::class) fun `Annotation is needed without whitelisting`() { - CordaClassResolver(EmptyWhitelist).getRegistration(NotSerializable::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(NotSerializable::class.java) } @Test fun `Annotation is not needed with whitelisting`() { - val resolver = CordaClassResolver(GlobalTransientClassWhiteList(EmptyWhitelist)) - (resolver.whitelist as MutableClassWhitelist).add(NotSerializable::class.java) + val resolver = CordaClassResolver(factory, emptyWhitelistContext.withWhitelisted(NotSerializable::class.java)) resolver.getRegistration(NotSerializable::class.java) } @Test fun `Annotation not needed on Object`() { - CordaClassResolver(EmptyWhitelist).getRegistration(Object::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Object::class.java) } @Test fun `Annotation not needed on primitive`() { - CordaClassResolver(EmptyWhitelist).getRegistration(Integer.TYPE) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Integer.TYPE) } @Test(expected = KryoException::class) fun `Annotation does not work for custom serializable`() { - CordaClassResolver(EmptyWhitelist).getRegistration(CustomSerializable::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(CustomSerializable::class.java) } @Test(expected = KryoException::class) fun `Annotation does not work in conjunction with Kryo annotation`() { - CordaClassResolver(EmptyWhitelist).getRegistration(DefaultSerializable::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(DefaultSerializable::class.java) } private fun importJar(storage: AttachmentStorage) = AttachmentClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) } @@ -151,20 +164,20 @@ class CordaClassResolverTests { val attachmentHash = importJar(storage) val classLoader = AttachmentsClassLoader(arrayOf(attachmentHash).map { storage.openAttachment(it)!! }) val attachedClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, classLoader) - CordaClassResolver(EmptyWhitelist).getRegistration(attachedClass) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(attachedClass) } @Test fun `Annotation is inherited from interfaces`() { - CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaInterface::class.java) - CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaSubInterface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaInterface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSubInterface::class.java) } @Test fun `Annotation is inherited from superclass`() { - CordaClassResolver(EmptyWhitelist).getRegistration(SubElement::class.java) - CordaClassResolver(EmptyWhitelist).getRegistration(SubSubElement::class.java) - CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaSuperSubInterface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubElement::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubSubElement::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSuperSubInterface::class.java) } // Blacklist tests. @@ -175,7 +188,7 @@ class CordaClassResolverTests { fun `Check blacklisted class`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("Class java.util.HashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // HashSet is blacklisted. resolver.getRegistration(HashSet::class.java) } @@ -185,7 +198,7 @@ class CordaClassResolverTests { fun `Check blacklisted subclass`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubHashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // SubHashSet extends the blacklisted HashSet. resolver.getRegistration(SubHashSet::class.java) } @@ -195,7 +208,7 @@ class CordaClassResolverTests { fun `Check blacklisted subsubclass`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubSubHashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // SubSubHashSet extends SubHashSet, which extends the blacklisted HashSet. resolver.getRegistration(SubSubHashSet::class.java) } @@ -205,7 +218,7 @@ class CordaClassResolverTests { fun `Check blacklisted interface impl`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$ConnectionImpl is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // ConnectionImpl implements blacklisted Connection. resolver.getRegistration(ConnectionImpl::class.java) } @@ -216,14 +229,14 @@ class CordaClassResolverTests { fun `Check blacklisted super-interface impl`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubConnectionImpl is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // SubConnectionImpl implements SubConnection, which extends the blacklisted Connection. resolver.getRegistration(SubConnectionImpl::class.java) } @Test fun `Check forcibly allowed`() { - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // LinkedHashSet is allowed for serialization. resolver.getRegistration(LinkedHashSet::class.java) } @@ -234,7 +247,7 @@ class CordaClassResolverTests { fun `Check blacklist precedes CordaSerializable`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$CordaSerializableHashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // CordaSerializableHashSet is @CordaSerializable, but extends the blacklisted HashSet. resolver.getRegistration(CordaSerializableHashSet::class.java) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt index 702c66b9b7..66918ac812 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt @@ -12,6 +12,7 @@ import net.corda.core.utilities.sequence import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.services.persistence.NodeAttachmentService import net.corda.testing.ALICE_PUBKEY +import net.corda.testing.TestDependencyInjectionBase import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Before @@ -23,13 +24,13 @@ import java.time.Instant import kotlin.test.assertEquals import kotlin.test.assertTrue -class KryoTests { +class KryoTests : TestDependencyInjectionBase() { private lateinit var factory: SerializationFactory private lateinit var context: SerializationContext @Before fun setup() { - factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } context = SerializationContextImpl(KryoHeaderV0_1, javaClass.classLoader, AllWhitelist, @@ -199,7 +200,7 @@ class KryoTests { } } Tmp() - val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } val context = SerializationContextImpl(KryoHeaderV0_1, javaClass.classLoader, AllWhitelist, diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt index 8efd50fa0e..03ab48214d 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt @@ -8,19 +8,20 @@ import net.corda.core.node.ServiceHub import net.corda.core.serialization.* import net.corda.core.utilities.OpaqueBytes import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.testing.TestDependencyInjectionBase import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test import java.io.ByteArrayOutputStream -class SerializationTokenTest { +class SerializationTokenTest : TestDependencyInjectionBase() { lateinit var factory: SerializationFactory lateinit var context: SerializationContext @Before fun setup() { - factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } context = SerializationContextImpl(KryoHeaderV0_1, javaClass.classLoader, AllWhitelist, @@ -96,7 +97,7 @@ class SerializationTokenTest { val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver())) + val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(factory, this.context))) val stream = ByteArrayOutputStream() Output(stream).use { it.write(KryoHeaderV0_1.bytes) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt index b891201721..53bdaaafe1 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt @@ -11,8 +11,8 @@ import net.corda.core.identity.AbstractParty import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.LedgerTransaction import net.corda.nodeapi.RPCException +import net.corda.nodeapi.internal.serialization.AbstractAMQPSerializationScheme import net.corda.nodeapi.internal.serialization.EmptyWhitelist -import net.corda.nodeapi.internal.serialization.KryoAMQPSerializer import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive import net.corda.nodeapi.internal.serialization.amqp.custom.* import net.corda.testing.MEGA_CORP @@ -528,10 +528,10 @@ class SerializationOutputTests { val state = TransactionState(FooState(), MEGA_CORP) val factory = SerializerFactory() - KryoAMQPSerializer.registerCustomSerializers(factory) + AbstractAMQPSerializationScheme.registerCustomSerializers(factory) val factory2 = SerializerFactory() - KryoAMQPSerializer.registerCustomSerializers(factory2) + AbstractAMQPSerializationScheme.registerCustomSerializers(factory2) val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false) assertTrue(desState is TransactionState<*>) diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index eea378de27..db94242f56 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -331,7 +331,7 @@ open class Node(override val configuration: FullNodeConfiguration, private fun initialiseSerialization() { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { - registerScheme(KryoServerSerializationScheme()) + registerScheme(KryoServerSerializationScheme(this)) registerScheme(AMQPServerSerializationScheme()) } SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT diff --git a/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt b/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt index b35f91bacf..14b5ef144e 100644 --- a/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt +++ b/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt @@ -2,6 +2,7 @@ package net.corda.node.serialization import com.esotericsoftware.kryo.pool.KryoPool import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.utilities.ByteSequence import net.corda.node.services.messaging.RpcServerObservableSerializer import net.corda.nodeapi.RPCKryo @@ -9,7 +10,7 @@ import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 -class KryoServerSerializationScheme : AbstractKryoSerializationScheme() { +class KryoServerSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { return byteSequence == KryoHeaderV0_1 && target != SerializationContext.UseCase.RPCClient } @@ -20,7 +21,7 @@ class KryoServerSerializationScheme : AbstractKryoSerializationScheme() { override fun rpcServerKryoPool(context: SerializationContext): KryoPool { return KryoPool.Builder { - DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, context.whitelist)).apply { classLoader = context.deserializationClassLoader } + DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader } }.build() } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt b/node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt index 2c528cbfd5..69ada450ff 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt @@ -398,7 +398,7 @@ class X509UtilitiesTest { @Test fun `serialize - deserialize X509CertififcateHolder`() { - val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } val context = SerializationContextImpl(KryoHeaderV0_1, javaClass.classLoader, AllWhitelist, @@ -413,7 +413,7 @@ class X509UtilitiesTest { @Test fun `serialize - deserialize X509CertPath`() { - val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } val context = SerializationContextImpl(KryoHeaderV0_1, javaClass.classLoader, AllWhitelist, diff --git a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt index c250214420..9ec3760622 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt @@ -61,8 +61,8 @@ fun initialiseTestSerialization() { // Now configure all the testing related delegates. (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply { - registerScheme(KryoClientSerializationScheme()) - registerScheme(KryoServerSerializationScheme()) + registerScheme(KryoClientSerializationScheme(this)) + registerScheme(KryoServerSerializationScheme(this)) registerScheme(AMQPClientSerializationScheme()) registerScheme(AMQPServerSerializationScheme()) } @@ -139,4 +139,8 @@ class TestSerializationContext : SerializationContext { override fun withWhitelisted(clazz: Class<*>): SerializationContext { return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) } } + + override fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) } + } } diff --git a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt index cafcfb91c2..fec528f883 100644 --- a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt +++ b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt @@ -7,6 +7,7 @@ import com.typesafe.config.ConfigParseOptions import net.corda.core.internal.div import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializationFactory import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug @@ -89,13 +90,13 @@ class Verifier { private fun initialiseSerialization() { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { - registerScheme(KryoVerifierSerializationScheme) + registerScheme(KryoVerifierSerializationScheme(this)) } SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT } } - object KryoVerifierSerializationScheme : AbstractKryoSerializationScheme() { + class KryoVerifierSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { return byteSequence.equals(KryoHeaderV0_1) && target == SerializationContext.UseCase.P2P }