From d7ef385cc7667b3f379aff534c4fe7bb972007fa Mon Sep 17 00:00:00 2001 From: Michele Sollecito Date: Wed, 9 May 2018 15:58:18 +0700 Subject: [PATCH 1/4] [CORDA-1395] [CORDA-1378]: Control the max number of transaction dependencies. (#3047) --- .../net/corda/core/internal/FetchDataFlow.kt | 1 + .../core/internal/ResolveTransactionsFlow.kt | 37 ++++++-- .../ContractUpgradeTransactions.kt | 1 + .../transactions/NotaryChangeTransactions.kt | 1 + .../internal/ResolveTransactionsFlowTest.kt | 9 +- docs/source/changelog.rst | 1 + .../node/services/vault/NodeVaultService.kt | 17 +++- .../transactions/NotaryServiceTests.kt | 3 +- .../node/services/vault/VaultQueryTests.kt | 84 +++++++++++++++++-- 9 files changed, 128 insertions(+), 26 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt b/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt index 867a719949..c7a6cfd112 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt @@ -85,6 +85,7 @@ sealed class FetchDataFlow( for (hash in toFetch) { // We skip the validation here (with unwrap { it }) because we will do it below in validateFetchResponse. // The only thing checked is the object type. It is a protocol violation to send results out of order. + // TODO We need to page here after large messages will work. maybeItems += otherSideSession.sendAndReceive>(Request.Data(NonEmptySet.of(hash), dataType)).unwrap { it } } // Check for a buggy/malicious peer answering with something that we didn't ask for. diff --git a/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt b/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt index 5ceb17d094..d9a6f2c715 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt @@ -12,6 +12,8 @@ import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.utilities.exactAdd import java.util.* +import kotlin.collections.ArrayList +import kotlin.math.min // TODO: This code is currently unit tested by TwoPartyTradeFlowTests, it should have its own tests. /** @@ -20,8 +22,12 @@ import java.util.* * * @return a list of verified [SignedTransaction] objects, in a depth-first order. */ -class ResolveTransactionsFlow(private val txHashes: Set, - private val otherSide: FlowSession) : FlowLogic>() { +class ResolveTransactionsFlow(txHashesArg: Set, + private val otherSide: FlowSession) : FlowLogic() { + + // Need it ordered in terms of iteration. Needs to be a variable for the check-pointing logic to work. + private val txHashes = txHashesArg.toList() + /** * Resolves and validates the dependencies of the specified [SignedTransaction]. Fetches the attachments, but does * *not* validate or store the [SignedTransaction] itself. @@ -35,6 +41,8 @@ class ResolveTransactionsFlow(private val txHashes: Set, companion object { private fun dependencyIDs(stx: SignedTransaction) = stx.inputs.map { it.txhash }.toSet() + private const val RESOLUTION_PAGE_SIZE = 100 + /** * Topologically sorts the given transactions such that dependencies are listed before dependers. */ @JvmStatic @@ -83,10 +91,16 @@ class ResolveTransactionsFlow(private val txHashes: Set, @Suspendable @Throws(FetchDataFlow.HashNotFound::class) - override fun call(): List { + override fun call() { + val newTxns = ArrayList(txHashes.size) // Start fetching data. - val newTxns = downloadDependencies(txHashes) - fetchMissingAttachments(signedTransaction?.let { newTxns + it } ?: newTxns) + for (pageNumber in 0..(txHashes.size - 1) / RESOLUTION_PAGE_SIZE) { + val page = page(pageNumber, RESOLUTION_PAGE_SIZE) + + newTxns += downloadDependencies(page) + val txsWithMissingAttachments = if (pageNumber == 0) signedTransaction?.let { newTxns + it } ?: newTxns else newTxns + fetchMissingAttachments(txsWithMissingAttachments) + } otherSide.send(FetchDataFlow.Request.End) // Finish fetching data. @@ -99,13 +113,17 @@ class ResolveTransactionsFlow(private val txHashes: Set, it.verify(serviceHub) serviceHub.recordTransactions(StatesToRecord.NONE, listOf(it)) } + } - return signedTransaction?.let { - result + it - } ?: result + private fun page(pageNumber: Int, pageSize: Int): Set { + val offset = pageNumber * pageSize + val limit = min(offset + pageSize, txHashes.size) + // call toSet() is needed because sub-lists are not checkpoint-friendly. + return txHashes.subList(offset, limit).toSet() } @Suspendable + // TODO use paging here (we literally get the entire dependencies graph in memory) private fun downloadDependencies(depsToCheck: Set): List { // Maintain a work queue of all hashes to load/download, initialised with our starting set. Then do a breadth // first traversal across the dependency graph. @@ -132,13 +150,14 @@ class ResolveTransactionsFlow(private val txHashes: Set, while (nextRequests.isNotEmpty()) { // Don't re-download the same tx when we haven't verified it yet but it's referenced multiple times in the // graph we're traversing. - val notAlreadyFetched = nextRequests.filterNot { it in resultQ }.toSet() + val notAlreadyFetched: Set = nextRequests - resultQ.keys nextRequests.clear() if (notAlreadyFetched.isEmpty()) // Done early. break // Request the standalone transaction data (which may refer to things we don't yet have). + // TODO use paging here val downloads: List = subFlow(FetchTransactionsFlow(notAlreadyFetched, otherSide)).downloaded for (stx in downloads) diff --git a/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt b/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt index e85b50f82e..d44cee239d 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt @@ -41,6 +41,7 @@ data class ContractUpgradeWireTransaction( init { check(inputs.isNotEmpty()) { "A contract upgrade transaction must have inputs" } + checkBaseInvariants() } /** diff --git a/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt b/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt index 0704ec0db2..fbee73d680 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt @@ -45,6 +45,7 @@ data class NotaryChangeWireTransaction( init { check(inputs.isNotEmpty()) { "A notary change transaction must have inputs" } check(notary != newNotary) { "The old and new notaries must be different – $newNotary" } + checkBaseInvariants() } /** diff --git a/core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt b/core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt index 8b762ecb9a..b652959156 100644 --- a/core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt +++ b/core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt @@ -58,8 +58,7 @@ class ResolveTransactionsFlowTest { val p = TestFlow(setOf(stx2.id), megaCorp) val future = miniCorpNode.startFlow(p) mockNet.runNetwork() - val results = future.getOrThrow() - assertEquals(listOf(stx1.id, stx2.id), results.map { it.id }) + future.getOrThrow() miniCorpNode.transaction { assertEquals(stx1, miniCorpNode.services.validatedTransactions.getTransaction(stx1.id)) assertEquals(stx2, miniCorpNode.services.validatedTransactions.getTransaction(stx2.id)) @@ -189,16 +188,16 @@ class ResolveTransactionsFlowTest { // DOCEND 2 @InitiatingFlow - private class TestFlow(val otherSide: Party, private val resolveTransactionsFlowFactory: (FlowSession) -> ResolveTransactionsFlow, private val txCountLimit: Int? = null) : FlowLogic>() { + private class TestFlow(val otherSide: Party, private val resolveTransactionsFlowFactory: (FlowSession) -> ResolveTransactionsFlow, private val txCountLimit: Int? = null) : FlowLogic() { constructor(txHashes: Set, otherSide: Party, txCountLimit: Int? = null) : this(otherSide, { ResolveTransactionsFlow(txHashes, it) }, txCountLimit = txCountLimit) constructor(stx: SignedTransaction, otherSide: Party) : this(otherSide, { ResolveTransactionsFlow(stx, it) }) @Suspendable - override fun call(): List { + override fun call() { val session = initiateFlow(otherSide) val resolveTransactionsFlow = resolveTransactionsFlowFactory(session) txCountLimit?.let { resolveTransactionsFlow.transactionCountLimit = it } - return subFlow(resolveTransactionsFlow) + subFlow(resolveTransactionsFlow) } } diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index ae292a8b71..4deb351720 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,7 @@ release, see :doc:`upgrade-notes`. Unreleased ========== +* Fixed an error thrown by NodeVaultService upon recording a transaction with a number of inputs greater than the default page size. * Fixed incorrect computation of ``totalStates`` from ``otherResults`` in ``NodeVaultService``. diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 038d25b562..817d0d7a40 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -188,9 +188,18 @@ class NodeVaultService( } private fun loadStates(refs: Collection): Collection> { - return if (refs.isNotEmpty()) - queryBy(QueryCriteria.VaultQueryCriteria(stateRefs = refs.toList())).states - else emptySet() + val states = mutableListOf>() + if (refs.isNotEmpty()) { + val refsList = refs.toList() + val pageSize = PageSpecification().pageSize + (0..(refsList.size - 1) / pageSize).forEach { + val offset = it * pageSize + val limit = minOf(offset + pageSize, refsList.size) + val page = queryBy(QueryCriteria.VaultQueryCriteria(stateRefs = refsList.subList(offset, limit))).states + states.addAll(page) + } + } + return states } private fun processAndNotify(updates: List>) { @@ -507,4 +516,4 @@ class NodeVaultService( } return myInterfaces } -} +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt index cc8f47a2a0..a09a027cee 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt @@ -63,7 +63,8 @@ class NotaryServiceTests { } private fun generateTransaction(node: StartedNode, party: Party, notary: Party): SignedTransaction { - val inputs = (1..10_005).map { StateRef(SecureHash.randomSHA256(), 0) } + val txHash = SecureHash.randomSHA256() + val inputs = (1..10_005).map { StateRef(txHash, it) } val tx = NotaryChangeTransactionBuilder(inputs, notary, party).build() return node.services.run { diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 3271385473..d08c63eaff 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -1,18 +1,58 @@ package net.corda.node.services.vault -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.PartyAndReference +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.crypto.Crypto import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.internal.packageName -import net.corda.core.node.services.* -import net.corda.core.node.services.vault.* -import net.corda.core.node.services.vault.QueryCriteria.* +import net.corda.core.node.services.IdentityService +import net.corda.core.node.services.Vault +import net.corda.core.node.services.VaultQueryException +import net.corda.core.node.services.VaultService +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.trackBy +import net.corda.core.node.services.vault.BinaryComparisonOperator +import net.corda.core.node.services.vault.ColumnPredicate +import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM +import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE +import net.corda.core.node.services.vault.MAX_PAGE_SIZE +import net.corda.core.node.services.vault.PageSpecification +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.FungibleAssetQueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.LinearStateQueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition +import net.corda.core.node.services.vault.QueryCriteria.SoftLockingType +import net.corda.core.node.services.vault.QueryCriteria.TimeCondition +import net.corda.core.node.services.vault.QueryCriteria.TimeInstantType +import net.corda.core.node.services.vault.QueryCriteria.VaultCustomQueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria +import net.corda.core.node.services.vault.Sort +import net.corda.core.node.services.vault.SortAttribute +import net.corda.core.node.services.vault.builder import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.* -import net.corda.finance.* +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.days +import net.corda.core.utilities.seconds +import net.corda.core.utilities.toHexString +import net.corda.finance.AMOUNT +import net.corda.finance.CHF +import net.corda.finance.DOLLARS +import net.corda.finance.GBP +import net.corda.finance.POUNDS +import net.corda.finance.SWISS_FRANCS +import net.corda.finance.USD +import net.corda.finance.`issued by` import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.DealState @@ -27,7 +67,18 @@ import net.corda.node.internal.configureDatabase import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseTransaction -import net.corda.testing.core.* +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.BOC_NAME +import net.corda.testing.core.CHARLIE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.TestIdentity +import net.corda.testing.core.dummyCommand +import net.corda.testing.core.expect +import net.corda.testing.core.expectEvents +import net.corda.testing.core.sequence +import net.corda.testing.core.singleIdentityAndCert import net.corda.testing.internal.TEST_TX_TIME import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.vault.DUMMY_LINEAR_CONTRACT_PROGRAM_ID @@ -38,6 +89,7 @@ import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndMockServices import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatCode import org.junit.ClassRule import org.junit.Ignore import org.junit.Rule @@ -2235,6 +2287,24 @@ abstract class VaultQueryTestsBase : VaultQueryParties { assertThat(exitStates).hasSize(0) } } + + @Test + fun `record a transaction with number of inputs greater than vault page size`() { + val notary = dummyNotary + val issuerKey = notary.keyPair + val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) + val states = database.transaction { + vaultFiller.fillWithSomeTestLinearStates(PageSpecification().pageSize + 1).states + } + + database.transaction { + val statesExitingTx = TransactionBuilder(notary.party).withItems(*states.toList().toTypedArray()).addCommand(dummyCommand()) + val signedStatesExitingTx = services.signInitialTransaction(statesExitingTx).withAdditionalSignature(issuerKey, signatureMetadata) + + assertThatCode { services.recordTransactions(signedStatesExitingTx) }.doesNotThrowAnyException() + } + } + /** * USE CASE demonstrations (outside of mainline Corda) * From a70e47969669e670781a16f966bd344e687ca74d Mon Sep 17 00:00:00 2001 From: Tudor Malene Date: Wed, 9 May 2018 11:16:36 +0100 Subject: [PATCH 2/4] ENT-1762 doc around jarDirs (#3094) * ENT-1762 doc around jarDirs --- docs/source/corda-configuration-file.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/corda-configuration-file.rst b/docs/source/corda-configuration-file.rst index c28c30a392..79d785c107 100644 --- a/docs/source/corda-configuration-file.rst +++ b/docs/source/corda-configuration-file.rst @@ -178,7 +178,8 @@ absolute path to the node's base directory. :jarDirs: An optional list of file system directories containing JARs to include in the classpath when launching via ``corda.jar`` only. Each should be a string. Only the JARs in the directories are added, not the directories themselves. This is useful - for including JDBC drivers and the like. e.g. ``jarDirs = [ 'lib' ]`` + for including JDBC drivers and the like. e.g. ``jarDirs = [ '${baseDirectory}/lib' ]`` (Note that you have to use the ``baseDirectory`` + substitution value when pointing to a relative path) :sshd: If provided, node will start internal SSH server which will provide a management shell. It uses the same credentials and permissions as RPC subsystem. It has one required parameter. From be11da76c8de0b284d820fb3ed29f7604874e2e9 Mon Sep 17 00:00:00 2001 From: Patrick Kuo Date: Wed, 9 May 2018 12:56:10 +0100 Subject: [PATCH 3/4] CORDA-1363 Network registration helper should check public key in certificate before storing in keystore (#3071) * check pub key in network registration helper before storing in DB --- .../registration/NetworkRegistrationHelper.kt | 13 ++- .../NetworkRegistrationHelperTest.kt | 82 +++++++++++-------- 2 files changed, 58 insertions(+), 37 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt b/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt index d36a2f8a11..8bb223694c 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt @@ -18,6 +18,7 @@ import java.io.StringWriter import java.nio.file.Path import java.security.KeyPair import java.security.KeyStore +import java.security.PublicKey import java.security.cert.X509Certificate /** @@ -85,14 +86,17 @@ open class NetworkRegistrationHelper(private val config: SSLConfiguration, requestIdStore.deleteIfExists() throw certificateRequestException } - validateCertificates(certificates) + validateCertificates(keyPair.public, certificates) storePrivateKeyWithCertificates(nodeKeyStore, keyPair, certificates, keyAlias) onSuccess(keyPair, certificates) // All done, clean up temp files. requestIdStore.deleteIfExists() + + println("Successfully registered Corda node with compatibility zone, node identity keys and certificates are stored in '${config.certificatesDirectory}', it is advised to backup the private keys and certificates.") + println("Corda node will now terminate.") } - private fun validateCertificates(certificates: List) { + private fun validateCertificates(registeringPublicKey: PublicKey, certificates: List) { val nodeCACertificate = certificates.first() val nodeCaSubject = try { @@ -114,6 +118,11 @@ open class NetworkRegistrationHelper(private val config: SSLConfiguration, throw CertificateRequestException("Received certificate contains invalid cert role, expected '$certRole', got '$nodeCaCertRole'.") } + // Validate returned certificate is for the correct public key. + if (Crypto.toSupportedPublicKey(certificates.first().publicKey) != Crypto.toSupportedPublicKey(registeringPublicKey)) { + throw CertificateRequestException("Received certificate contains incorrect public key, expected '$registeringPublicKey', got '${certificates.first().publicKey}'.") + } + // Validate certificate chain returned from the doorman with the root cert obtained via out-of-band process, to prevent MITM attack on doorman server. X509Utilities.validateCertificateChain(rootCert, certificates) println("Certificate signing request approved, storing private key with the certificate chain.") diff --git a/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelperTest.kt b/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelperTest.kt index e93805b717..5c82a90afa 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelperTest.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelperTest.kt @@ -3,8 +3,8 @@ package net.corda.node.utilities.registration import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Jimfs import com.nhaarman.mockito_kotlin.any +import com.nhaarman.mockito_kotlin.doAnswer import com.nhaarman.mockito_kotlin.doReturn -import com.nhaarman.mockito_kotlin.eq import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.Crypto import net.corda.core.crypto.SecureHash @@ -17,6 +17,7 @@ import net.corda.core.utilities.seconds import net.corda.node.NodeRegistrationOption import net.corda.node.services.config.NodeConfiguration import net.corda.nodeapi.internal.DevIdentityGenerator +import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.X509KeyStore import net.corda.nodeapi.internal.crypto.X509Utilities @@ -27,9 +28,12 @@ import org.assertj.core.api.Assertions.* import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralSubtree import org.bouncycastle.asn1.x509.NameConstraints +import org.bouncycastle.pkcs.PKCS10CertificationRequest +import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest import org.junit.After import org.junit.Before import org.junit.Test +import java.security.PublicKey import java.security.cert.CertPathValidatorException import java.security.cert.X509Certificate import javax.security.auth.x500.X500Principal @@ -37,7 +41,6 @@ import kotlin.test.assertFalse class NetworkRegistrationHelperTest { private val fs = Jimfs.newFileSystem(unix()) - private val requestId = SecureHash.randomSHA256().toString() private val nodeLegalName = ALICE_NAME private lateinit var config: NodeConfiguration @@ -69,10 +72,9 @@ class NetworkRegistrationHelperTest { assertThat(config.sslKeystore).doesNotExist() assertThat(config.trustStoreFile).doesNotExist() - val nodeCaCertPath = createNodeCaCertPath() + val rootAndIntermediateCA = createDevIntermediateCaCertPath().also { saveNetworkTrustStore(it.first.certificate) } - saveNetworkTrustStore(nodeCaCertPath.last()) - createRegistrationHelper(nodeCaCertPath).buildKeystore() + createRegistrationHelper(rootAndIntermediateCA = rootAndIntermediateCA).buildKeystore() val nodeKeystore = config.loadNodeKeyStore() val sslKeystore = config.loadSslKeyStore() @@ -82,7 +84,7 @@ class NetworkRegistrationHelperTest { assertFalse(contains(X509Utilities.CORDA_INTERMEDIATE_CA)) assertFalse(contains(X509Utilities.CORDA_ROOT_CA)) assertFalse(contains(X509Utilities.CORDA_CLIENT_TLS)) - assertThat(getCertificateChain(X509Utilities.CORDA_CLIENT_CA)).containsExactlyElementsOf(nodeCaCertPath) + assertThat(CertRole.extract(getCertificate(X509Utilities.CORDA_CLIENT_CA))).isEqualTo(CertRole.NODE_CA) } sslKeystore.run { @@ -93,13 +95,13 @@ class NetworkRegistrationHelperTest { assertThat(nodeTlsCertChain).hasSize(4) // The TLS cert has the same subject as the node CA cert assertThat(CordaX500Name.build(nodeTlsCertChain[0].subjectX500Principal)).isEqualTo(nodeLegalName) - assertThat(nodeTlsCertChain.drop(1)).containsExactlyElementsOf(nodeCaCertPath) + assertThat(CertRole.extract(nodeTlsCertChain.first())).isEqualTo(CertRole.TLS) } trustStore.run { assertFalse(contains(X509Utilities.CORDA_CLIENT_CA)) assertFalse(contains(X509Utilities.CORDA_INTERMEDIATE_CA)) - assertThat(getCertificate(X509Utilities.CORDA_ROOT_CA)).isEqualTo(nodeCaCertPath.last()) + assertThat(getCertificate(X509Utilities.CORDA_ROOT_CA)).isEqualTo(rootAndIntermediateCA.first.certificate) } } @@ -107,7 +109,7 @@ class NetworkRegistrationHelperTest { fun `missing truststore`() { val nodeCaCertPath = createNodeCaCertPath() assertThatThrownBy { - createRegistrationHelper(nodeCaCertPath) + createFixedResponseRegistrationHelper(nodeCaCertPath) }.hasMessageContaining("This file must contain the root CA cert of your compatibility zone. Please contact your CZ operator.") } @@ -115,7 +117,7 @@ class NetworkRegistrationHelperTest { fun `node CA with incorrect cert role`() { val nodeCaCertPath = createNodeCaCertPath(type = CertificateType.TLS) saveNetworkTrustStore(nodeCaCertPath.last()) - val registrationHelper = createRegistrationHelper(nodeCaCertPath) + val registrationHelper = createFixedResponseRegistrationHelper(nodeCaCertPath) assertThatExceptionOfType(CertificateRequestException::class.java) .isThrownBy { registrationHelper.buildKeystore() } .withMessageContaining(CertificateType.TLS.toString()) @@ -126,7 +128,7 @@ class NetworkRegistrationHelperTest { val invalidName = CordaX500Name("Foo", "MU", "GB") val nodeCaCertPath = createNodeCaCertPath(legalName = invalidName) saveNetworkTrustStore(nodeCaCertPath.last()) - val registrationHelper = createRegistrationHelper(nodeCaCertPath) + val registrationHelper = createFixedResponseRegistrationHelper(nodeCaCertPath) assertThatExceptionOfType(CertificateRequestException::class.java) .isThrownBy { registrationHelper.buildKeystore() } .withMessageContaining(invalidName.toString()) @@ -138,7 +140,8 @@ class NetworkRegistrationHelperTest { X500Principal("O=Foo,L=MU,C=GB"), Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)) saveNetworkTrustStore(wrongRootCert) - val registrationHelper = createRegistrationHelper(createNodeCaCertPath()) + + val registrationHelper = createRegistrationHelper() assertThatThrownBy { registrationHelper.buildKeystore() }.isInstanceOf(CertPathValidatorException::class.java) @@ -150,10 +153,9 @@ class NetworkRegistrationHelperTest { assertThat(config.sslKeystore).doesNotExist() assertThat(config.trustStoreFile).doesNotExist() - val serviceIdentityCertPath = createServiceIdentityCertPath() + val rootAndIntermediateCA = createDevIntermediateCaCertPath().also { saveNetworkTrustStore(it.first.certificate) } - saveNetworkTrustStore(serviceIdentityCertPath.last()) - createRegistrationHelper(serviceIdentityCertPath, CertRole.SERVICE_IDENTITY).buildKeystore() + createRegistrationHelper(CertRole.SERVICE_IDENTITY, rootAndIntermediateCA).buildKeystore() val nodeKeystore = config.loadNodeKeyStore() @@ -167,42 +169,52 @@ class NetworkRegistrationHelperTest { assertFalse(contains(X509Utilities.CORDA_ROOT_CA)) assertFalse(contains(X509Utilities.CORDA_CLIENT_TLS)) assertFalse(contains(X509Utilities.CORDA_CLIENT_CA)) - assertThat(getCertificateChain(serviceIdentityAlias)).containsExactlyElementsOf(serviceIdentityCertPath) + assertThat(CertRole.extract(getCertificate(serviceIdentityAlias))).isEqualTo(CertRole.SERVICE_IDENTITY) } } private fun createNodeCaCertPath(type: CertificateType = CertificateType.NODE_CA, - legalName: CordaX500Name = nodeLegalName): List { - val (rootCa, intermediateCa) = createDevIntermediateCaCertPath() - val keyPair = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val nameConstraints = NameConstraints(arrayOf(GeneralSubtree(GeneralName(GeneralName.directoryName, legalName.x500Name))), arrayOf()) + legalName: CordaX500Name = nodeLegalName, + publicKey: PublicKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME).public, + rootAndIntermediateCA: Pair = createDevIntermediateCaCertPath()): List { + val (rootCa, intermediateCa) = rootAndIntermediateCA + val nameConstraints = if (type == CertificateType.NODE_CA) { + NameConstraints(arrayOf(GeneralSubtree(GeneralName(GeneralName.directoryName, legalName.x500Name))), arrayOf()) + } else { + null + } val nodeCaCert = X509Utilities.createCertificate( type, intermediateCa.certificate, intermediateCa.keyPair, legalName.x500Principal, - keyPair.public, + publicKey, nameConstraints = nameConstraints) return listOf(nodeCaCert, intermediateCa.certificate, rootCa.certificate) } - private fun createServiceIdentityCertPath(type: CertificateType = CertificateType.SERVICE_IDENTITY, - legalName: CordaX500Name = nodeLegalName): List { - val (rootCa, intermediateCa) = createDevIntermediateCaCertPath() - val keyPair = Crypto.generateKeyPair() - val serviceIdentityCert = X509Utilities.createCertificate( - type, - intermediateCa.certificate, - intermediateCa.keyPair, - legalName.x500Principal, - keyPair.public) - return listOf(serviceIdentityCert, intermediateCa.certificate, rootCa.certificate) + private fun createFixedResponseRegistrationHelper(response: List, certRole: CertRole = CertRole.NODE_CA): NetworkRegistrationHelper { + return createRegistrationHelper(certRole) { response } } - private fun createRegistrationHelper(response: List, certRole: CertRole = CertRole.NODE_CA): NetworkRegistrationHelper { + private fun createRegistrationHelper(certRole: CertRole = CertRole.NODE_CA, rootAndIntermediateCA: Pair = createDevIntermediateCaCertPath()) = createRegistrationHelper(certRole) { + val certType = CertificateType.values().first { it.role == certRole } + createNodeCaCertPath(rootAndIntermediateCA = rootAndIntermediateCA, publicKey = it.publicKey, type = certType) + } + + private fun createRegistrationHelper(certRole: CertRole = CertRole.NODE_CA, dynamicResponse: (JcaPKCS10CertificationRequest) -> List): NetworkRegistrationHelper { val certService = rigorousMock().also { - doReturn(requestId).whenever(it).submitRequest(any()) - doReturn(CertificateResponse(5.seconds, response)).whenever(it).retrieveCertificates(eq(requestId)) + val requests = mutableMapOf() + doAnswer { + val requestId = SecureHash.randomSHA256().toString() + val request = JcaPKCS10CertificationRequest(it.getArgument(0)) + requests[requestId] = request + requestId + }.whenever(it).submitRequest(any()) + + doAnswer { + CertificateResponse(5.seconds, dynamicResponse(requests[it.getArgument(0)]!!)) + }.whenever(it).retrieveCertificates(any()) } return when (certRole) { From 781b50642aec9deeeadee219318509e050f9026e Mon Sep 17 00:00:00 2001 From: Chris Rankin Date: Wed, 9 May 2018 13:37:04 +0100 Subject: [PATCH 4/4] ENT-1463: Prepare node-api for determination. (#3080) * Prepare node-api for determination. * Disentangle Kryo and AMQP classes. * Add version properties for fast-classpath-scanner, proton-j and snappy. * Remove String.jvm extension function. * Refactor Cordapp reference out of AMQP serialisers' primary constructors. --- build.gradle | 5 +- node-api/build.gradle | 9 +-- .../serialization/ByteBufferStreams.kt | 63 +++++++++++++++++++ .../serialization/SerializationScheme.kt | 52 +++++++++------ .../amqp/AMQPSerializationScheme.kt | 28 ++++++--- .../serialization/amqp/AMQPStreams.kt | 7 ++- .../amqp/DeserializationInput.kt | 2 +- .../serialization/amqp/SerializationOutput.kt | 2 +- .../serialization/amqp/SerializerFactory.kt | 3 +- .../serialization/carpenter/ClassCarpenter.kt | 27 ++++---- .../kryo/KryoSerializationScheme.kt | 5 +- .../serialization/kryo/KryoStreams.kt | 35 +---------- .../amqp/SerializationSchemaTests.kt | 2 +- .../serialization/kryo/KryoStreamsTest.kt | 1 + 14 files changed, 149 insertions(+), 92 deletions(-) create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ByteBufferStreams.kt diff --git a/build.gradle b/build.gradle index fe6a9308eb..4476afa75d 100644 --- a/build.gradle +++ b/build.gradle @@ -21,7 +21,7 @@ buildscript { // TODO: Upgrade gradle-capsule-plugin to a version with capsule:1.0.3 ext.capsule_version = '1.0.1' - ext.asm_version = '0.5.3' + ext.asm_version = '5.0.4' /* * TODO Upgrade to version 2.4 for large message streaming support @@ -77,6 +77,9 @@ buildscript { ext.eaagentloader_version = '1.0.3' ext.jsch_version = '0.1.54' ext.commons_cli_version = '1.4' + ext.protonj_version = '0.27.1' + ext.snappy_version = '0.4' + ext.fast_classpath_scanner_version = '2.12.3' // Update 121 is required for ObjectInputFilter and at time of writing 131 was latest: ext.java8_minUpdateVersion = '131' diff --git a/node-api/build.gradle b/node-api/build.gradle index 4f3beb3890..745749c4ae 100644 --- a/node-api/build.gradle +++ b/node-api/build.gradle @@ -33,15 +33,16 @@ dependencies { // Kryo: object graph serialization. compile "com.esotericsoftware:kryo:4.0.0" compile "de.javakaffee:kryo-serializers:0.41" + compile "org.ow2.asm:asm:$asm_version" // For AMQP serialisation. - compile "org.apache.qpid:proton-j:0.27.1" + compile "org.apache.qpid:proton-j:$protonj_version" - // FastClasspathScanner: classpath scanning - needed for the NetworkBootstraper - compile 'io.github.lukehutch:fast-classpath-scanner:2.12.3' + // FastClasspathScanner: classpath scanning - needed for the NetworkBootstrapper and AMQP. + compile "io.github.lukehutch:fast-classpath-scanner:$fast_classpath_scanner_version" // Pure-Java Snappy compression - compile 'org.iq80.snappy:snappy:0.4' + compile "org.iq80.snappy:snappy:$snappy_version" // For caches rather than guava compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ByteBufferStreams.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ByteBufferStreams.kt new file mode 100644 index 0000000000..f3710569ed --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ByteBufferStreams.kt @@ -0,0 +1,63 @@ +@file:JvmName("ByteBufferStreams") +package net.corda.nodeapi.internal.serialization + +import net.corda.core.internal.LazyPool +import java.io.ByteArrayOutputStream +import java.io.IOException +import java.io.InputStream +import java.io.OutputStream +import java.nio.ByteBuffer +import kotlin.math.min + +internal val serializeOutputStreamPool = LazyPool( + clear = ByteBufferOutputStream::reset, + shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large + newInstance = { ByteBufferOutputStream(64 * 1024) }) + +internal fun byteArrayOutput(task: (ByteBufferOutputStream) -> T): ByteArray { + return serializeOutputStreamPool.run { underlying -> + task(underlying) + underlying.toByteArray() // Must happen after close, to allow ZIP footer to be written for example. + } +} + +class ByteBufferInputStream(val byteBuffer: ByteBuffer) : InputStream() { + @Throws(IOException::class) + override fun read(): Int { + return if (byteBuffer.hasRemaining()) byteBuffer.get().toInt() else -1 + } + + @Throws(IOException::class) + override fun read(b: ByteArray, offset: Int, length: Int): Int { + if (offset < 0 || length < 0 || length > b.size - offset) { + throw IndexOutOfBoundsException() + } else if (length == 0) { + return 0 + } else if (!byteBuffer.hasRemaining()) { + return -1 + } + val size = min(length, byteBuffer.remaining()) + byteBuffer.get(b, offset, size) + return size + } +} + +class ByteBufferOutputStream(size: Int) : ByteArrayOutputStream(size) { + companion object { + private val ensureCapacity = ByteArrayOutputStream::class.java.getDeclaredMethod("ensureCapacity", Int::class.java).apply { + isAccessible = true + } + } + + fun alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T { + ensureCapacity.invoke(this, count + remaining) + val buffer = ByteBuffer.wrap(buf, count, remaining) + val result = task(buffer) + count = buffer.position() + return result + } + + fun copyTo(stream: OutputStream) { + stream.write(buf, 0, count) + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt index d99227f800..33962ab02b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt @@ -16,7 +16,7 @@ import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutionException -val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled" +const val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled" internal object NullEncodingWhitelist : EncodingWhitelist { override fun acceptEncoding(encoding: SerializationEncoding) = false @@ -30,7 +30,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe override val useCase: SerializationContext.UseCase, override val encoding: SerializationEncoding?, override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : SerializationContext { - private val cache: Cache, AttachmentsClassLoader> = Caffeine.newBuilder().weakValues().maximumSize(1024).build() + private val builder = AttachmentsClassLoaderBuilder(properties, deserializationClassLoader) /** * {@inheritDoc} @@ -39,23 +39,8 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe */ override fun withAttachmentsClassLoader(attachmentHashes: List): SerializationContext { properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean == true || return this - val serializationContext = properties[serializationContextKey] as? SerializeAsTokenContextImpl - ?: return this // Some tests don't set one. - try { - return withClassLoader(cache.get(attachmentHashes) { - val missing = ArrayList() - val attachments = ArrayList() - attachmentHashes.forEach { id -> - serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } - ?: run { missing += id } - } - missing.isNotEmpty() && throw MissingAttachmentsException(missing) - AttachmentsClassLoader(attachments, parent = deserializationClassLoader) - }!!) - } catch (e: ExecutionException) { - // Caught from within the cache get, so unwrap. - throw e.cause!! - } + val classLoader = builder.build(attachmentHashes) ?: return this + return withClassLoader(classLoader) } override fun withProperty(property: Any, value: Any): SerializationContext { @@ -80,6 +65,33 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding) } +/* + * This class is internal rather than private so that node-api-deterministic + * can replace it with an alternative version. + */ +internal class AttachmentsClassLoaderBuilder(private val properties: Map, private val deserializationClassLoader: ClassLoader) { + private val cache: Cache, AttachmentsClassLoader> = Caffeine.newBuilder().weakValues().maximumSize(1024).build() + + fun build(attachmentHashes: List): AttachmentsClassLoader? { + val serializationContext = properties[serializationContextKey] as? SerializeAsTokenContext ?: return null // Some tests don't set one. + try { + return cache.get(attachmentHashes) { + val missing = ArrayList() + val attachments = ArrayList() + attachmentHashes.forEach { id -> + serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } + ?: run { missing += id } + } + missing.isNotEmpty() && throw MissingAttachmentsException(missing) + AttachmentsClassLoader(attachments, parent = deserializationClassLoader) + }!! + } catch (e: ExecutionException) { + // Caught from within the cache get, so unwrap. + throw e.cause!! + } + } +} + open class SerializationFactoryImpl : SerializationFactory() { companion object { val magicSize = sequenceOf(kryoMagic, amqpMagic).map { it.size }.distinct().single() @@ -152,4 +164,4 @@ interface SerializationScheme { @Throws(NotSerializableException::class) fun serialize(obj: T, context: SerializationContext): SerializedBytes -} \ No newline at end of file +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt index ed1ab9c4ed..4c16e1f44a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt @@ -35,9 +35,11 @@ open class SerializerFactoryFactory { } abstract class AbstractAMQPSerializationScheme( - val cordappLoader: List, + private val cordappCustomSerializers: Set>, val sff: SerializerFactoryFactory = SerializerFactoryFactory() ) : SerializationScheme { + constructor(cordapps: List) : this(cordapps.customSerializers) + // TODO: This method of initialisation for the Whitelist and plugin serializers will have to change // when we have per-cordapp contexts and dynamic app reloading but for now it's the easiest way companion object { @@ -62,6 +64,8 @@ abstract class AbstractAMQPSerializationScheme( .map { it.kotlin.objectOrNewInstance() } } } + + val List.customSerializers get() = flatMap { it.serializationCustomSerializers }.toSet() } private fun registerCustomSerializers(context: SerializationContext, factory: SerializerFactory) { @@ -103,15 +107,13 @@ abstract class AbstractAMQPSerializationScheme( // If we're passed in an external list we trust that, otherwise revert to looking at the scan of the // classpath to find custom serializers. - if (cordappLoader.isEmpty()) { + if (cordappCustomSerializers.isEmpty()) { for (customSerializer in customSerializers) { factory.registerExternal(CorDappCustomSerializer(customSerializer, factory)) } } else { - cordappLoader.forEach { loader -> - for (customSerializer in loader.serializationCustomSerializers) { - factory.registerExternal(CorDappCustomSerializer(customSerializer, factory)) - } + cordappCustomSerializers.forEach { customSerializer -> + factory.registerExternal(CorDappCustomSerializer(customSerializer, factory)) } } @@ -154,13 +156,16 @@ abstract class AbstractAMQPSerializationScheme( } // TODO: This will eventually cover server RPC as well and move to node module, but for now this is not implemented -class AMQPServerSerializationScheme(cordapps: List = emptyList()) : AbstractAMQPSerializationScheme(cordapps) { +class AMQPServerSerializationScheme(cordappCustomSerializers: Set> = emptySet()) + : AbstractAMQPSerializationScheme(cordappCustomSerializers) { + constructor(cordapps: List) : this(cordapps.customSerializers) + override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { throw UnsupportedOperationException() } override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + throw UnsupportedOperationException() } override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean { @@ -171,9 +176,12 @@ class AMQPServerSerializationScheme(cordapps: List = emptyList()) : Abs } // TODO: This will eventually cover client RPC as well and move to client module, but for now this is not implemented -class AMQPClientSerializationScheme(cordapps: List = emptyList()) : AbstractAMQPSerializationScheme(cordapps) { +class AMQPClientSerializationScheme(cordappCustomSerializers: Set> = emptySet()) + : AbstractAMQPSerializationScheme(cordappCustomSerializers) { + constructor(cordapps: List) : this(cordapps.customSerializers) + override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + throw UnsupportedOperationException() } override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt index f45ac6d864..dee5d8e425 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt @@ -1,8 +1,9 @@ +@file:JvmName("AMQPStreams") package net.corda.nodeapi.internal.serialization.amqp -import com.esotericsoftware.kryo.io.ByteBufferInputStream -import net.corda.nodeapi.internal.serialization.kryo.ByteBufferOutputStream -import net.corda.nodeapi.internal.serialization.kryo.serializeOutputStreamPool +import net.corda.nodeapi.internal.serialization.ByteBufferInputStream +import net.corda.nodeapi.internal.serialization.ByteBufferOutputStream +import net.corda.nodeapi.internal.serialization.serializeOutputStreamPool import java.io.InputStream import java.io.OutputStream import java.nio.ByteBuffer diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt index 71c303439b..592d21f638 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt @@ -37,7 +37,7 @@ class DeserializationInput @JvmOverloads constructor(private val serializerFacto private val objectHistory: MutableList = mutableListOf() companion object { - private val BYTES_NEEDED_TO_PEEK: Int = 23 + private const val BYTES_NEEDED_TO_PEEK: Int = 23 fun peekSize(bytes: ByteArray): Int { // There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt index 3290c063ba..7a952e5310 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt @@ -5,7 +5,7 @@ import net.corda.core.serialization.SerializationEncoding import net.corda.core.serialization.SerializedBytes import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding import net.corda.nodeapi.internal.serialization.SectionId -import net.corda.nodeapi.internal.serialization.kryo.byteArrayOutput +import net.corda.nodeapi.internal.serialization.byteArrayOutput import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException import java.io.OutputStream diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index abf8034115..816f35d570 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -2,7 +2,6 @@ package net.corda.nodeapi.internal.serialization.amqp import com.google.common.primitives.Primitives import com.google.common.reflect.TypeResolver -import net.corda.core.internal.getStackTraceAsString import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.ClassWhitelist import net.corda.core.utilities.loggerFor @@ -247,7 +246,7 @@ open class SerializerFactory( // preserve the actual message locally loggerFor().apply { error("${e.message} [hint: enable trace debugging for the stack trace]") - trace(e.getStackTraceAsString()) + trace("", e) } // prevent carpenter exceptions escaping into the world, convert things into a nice diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index 6dc7453a1a..5ac6c95c7c 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -1,5 +1,6 @@ package net.corda.nodeapi.internal.serialization.carpenter +import com.google.common.base.MoreObjects import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable import org.objectweb.asm.ClassWriter @@ -23,7 +24,7 @@ interface SimpleFieldAccess { class CarpenterClassLoader(parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : ClassLoader(parentClassLoader) { - fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) + fun load(name: String, bytes: ByteArray): Class<*> = defineClass(name, bytes, 0, bytes.size) } class InterfaceMismatchNonGetterException(val clazz: Class<*>, val method: Method) : InterfaceMismatchException( @@ -37,10 +38,12 @@ class InterfaceMismatchMissingAMQPFieldException(val clazz: Class<*>, val field: */ private const val TARGET_VERSION = V1_8 -private val jlEnum get() = Type.getInternalName(Enum::class.java) -private val jlString get() = Type.getInternalName(String::class.java) -private val jlObject get() = Type.getInternalName(Object::class.java) -private val jlClass get() = Type.getInternalName(Class::class.java) +private val jlEnum: String = Type.getInternalName(Enum::class.java) +private val jlString: String = Type.getInternalName(String::class.java) +private val jlObject: String = Type.getInternalName(Object::class.java) +private val jlClass: String = Type.getInternalName(Class::class.java) +private val moreObjects: String = Type.getInternalName(MoreObjects::class.java) +private val toStringHelper: String = Type.getInternalName(MoreObjects.ToStringHelper::class.java) /** * A class carpenter generates JVM bytecodes for a class given a schema and then loads it into a sub-classloader. @@ -97,7 +100,6 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader val classloader = CarpenterClassLoader(cl) private val _loaded = HashMap>() - private val String.jvm: String get() = replace(".", "/") /** Returns a snapshot of the currently loaded classes as a map of full class name (package names+dots) -> class object */ val loaded: Map> = HashMap(_loaded) @@ -155,7 +157,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader private fun generateInterface(interfaceSchema: Schema): Class<*> { return generate(interfaceSchema) { cw, schema -> - val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray() + val interfaces = schema.interfaces.map { Type.getInternalName(it) }.toTypedArray() cw.apply { visit(TARGET_VERSION, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, @@ -172,12 +174,12 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader private fun generateClass(classSchema: Schema): Class<*> { return generate(classSchema) { cw, schema -> val superName = schema.superclass?.jvmName ?: jlObject - val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList() + val interfaces = schema.interfaces.map { Type.getInternalName(it) }.toMutableList() if (SimpleFieldAccess::class.java !in schema.interfaces && schema.flags.cordaSerializable() && schema.flags.simpleFieldAccess()) { - interfaces.add(SimpleFieldAccess::class.java.name.jvm) + interfaces.add(Type.getInternalName(SimpleFieldAccess::class.java)) } cw.apply { @@ -214,12 +216,11 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } private fun ClassWriter.generateToString(schema: Schema) { - val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper" with(visitMethod(ACC_PUBLIC, "toString", "()L$jlString;", null, null)) { visitCode() // com.google.common.base.MoreObjects.toStringHelper("TypeName") visitLdcInsn(schema.name.split('.').last()) - visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", + visitMethodInsn(INVOKESTATIC, moreObjects, "toStringHelper", "(L$jlString;)L$toStringHelper;", false) // Call the add() methods. for ((name, field) in schema.fieldsIncludingSuperclasses().entries) { @@ -237,7 +238,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } private fun ClassWriter.generateGetMethod() { - val ourJvmName = ClassCarpenter::class.java.name.jvm + val ourJvmName = Type.getInternalName(ClassCarpenter::class.java) with(visitMethod(ACC_PUBLIC, "get", "(L$jlString;)L$jlObject;", null, null)) { visitCode() visitVarInsn(ALOAD, 0) // Load 'this' @@ -372,7 +373,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader var slot = 1 superclassFields.values.forEach { slot += load(slot, it) } val superDesc = sc.descriptorsIncludingSuperclasses().values.joinToString("") - visitMethodInsn(INVOKESPECIAL, sc.name.jvm, "", "($superDesc)V", false) + visitMethodInsn(INVOKESPECIAL, sc.jvmName, "", "($superDesc)V", false) } // Assign the fields from parameters. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt index 874219e1ec..405ae1b72c 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt @@ -5,7 +5,6 @@ import co.paralleluniverse.io.serialization.kryo.KryoSerializer import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.ByteBufferInputStream import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.pool.KryoPool @@ -39,8 +38,8 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool - // this can be overriden in derived serialization schemes - open protected val publicKeySerializer: Serializer = PublicKeySerializer + // this can be overridden in derived serialization schemes + protected open val publicKeySerializer: Serializer = PublicKeySerializer private fun getPool(context: SerializationContext): KryoPool { return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt index d88d943a79..977f4077fc 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt @@ -1,40 +1,16 @@ +@file:JvmName("KryoStreams") package net.corda.nodeapi.internal.serialization.kryo import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import net.corda.core.internal.LazyPool -import java.io.ByteArrayOutputStream +import net.corda.nodeapi.internal.serialization.byteArrayOutput import java.io.InputStream import java.io.OutputStream import java.io.SequenceInputStream -import java.nio.ByteBuffer - -class ByteBufferOutputStream(size: Int) : ByteArrayOutputStream(size) { - companion object { - private val ensureCapacity = ByteArrayOutputStream::class.java.getDeclaredMethod("ensureCapacity", Int::class.java).apply { - isAccessible = true - } - } - - fun alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T { - ensureCapacity.invoke(this, count + remaining) - val buffer = ByteBuffer.wrap(buf, count, remaining) - val result = task(buffer) - count = buffer.position() - return result - } - - fun copyTo(stream: OutputStream) { - stream.write(buf, 0, count) - } -} private val serializationBufferPool = LazyPool( newInstance = { ByteArray(64 * 1024) }) -internal val serializeOutputStreamPool = LazyPool( - clear = ByteBufferOutputStream::reset, - shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large - newInstance = { ByteBufferOutputStream(64 * 1024) }) internal fun kryoInput(underlying: InputStream, task: Input.() -> T): T { return serializationBufferPool.run { @@ -56,13 +32,6 @@ internal fun kryoOutput(task: Output.() -> T): ByteArray { } } -internal fun byteArrayOutput(task: (ByteBufferOutputStream) -> T): ByteArray { - return serializeOutputStreamPool.run { underlying -> - task(underlying) - underlying.toByteArray() // Must happen after close, to allow ZIP footer to be written for example. - } -} - internal fun Output.substitute(transform: (OutputStream) -> OutputStream) { flush() outputStream = transform(outputStream) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationSchemaTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationSchemaTests.kt index 066278b2c7..e1218ff917 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationSchemaTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationSchemaTests.kt @@ -43,7 +43,7 @@ class TestSerializerFactoryFactory : SerializerFactoryFactory() { } } -class AMQPTestSerializationScheme : AbstractAMQPSerializationScheme(emptyList(), TestSerializerFactoryFactory()) { +class AMQPTestSerializationScheme : AbstractAMQPSerializationScheme(emptySet(), TestSerializerFactoryFactory()) { override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { throw UnsupportedOperationException() } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt index d8eedd305d..881d304c41 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt @@ -1,6 +1,7 @@ package net.corda.nodeapi.internal.serialization.kryo import net.corda.core.internal.declaredField +import net.corda.nodeapi.internal.serialization.ByteBufferOutputStream import org.assertj.core.api.Assertions.catchThrowable import org.junit.Assert.assertArrayEquals import org.junit.Test