diff --git a/src/main/kotlin/core/testing/MockNode.kt b/src/main/kotlin/core/testing/MockNode.kt index e4d0335c56..e0effdaeb8 100644 --- a/src/main/kotlin/core/testing/MockNode.kt +++ b/src/main/kotlin/core/testing/MockNode.kt @@ -34,7 +34,8 @@ import java.util.concurrent.Executors * for message exchanges to take place (and associated handlers to run), you must call the [runNetwork] * method. */ -class MockNetwork(private val threadPerNode: Boolean = false) { +class MockNetwork(private val threadPerNode: Boolean = false, + private val defaultFactory: Factory = MockNetwork.DefaultFactory) { private var counter = 0 val filesystem = Jimfs.newFileSystem(com.google.common.jimfs.Configuration.unix()) val messagingNetwork = InMemoryMessagingNetwork() @@ -49,6 +50,19 @@ class MockNetwork(private val threadPerNode: Boolean = false) { Files.createDirectory(filesystem.getPath("/nodes")) } + /** Allows customisation of how nodes are created. */ + interface Factory { + fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, + timestamperAddr: LegallyIdentifiableNode?): MockNode + } + + object DefaultFactory : Factory { + override fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, + timestamperAddr: LegallyIdentifiableNode?): MockNode { + return MockNode(dir, config, network, timestamperAddr) + } + } + open class MockNode(dir: Path, config: NodeConfiguration, val mockNet: MockNetwork, withTimestamper: LegallyIdentifiableNode?, val forcedID: Int = -1) : AbstractNode(dir, config, withTimestamper, Clock.systemUTC()) { override val log: Logger = loggerFor<MockNode>() @@ -81,8 +95,7 @@ class MockNetwork(private val threadPerNode: Boolean = false) { } /** Returns a started node, optionally created by the passed factory method */ - fun createNode(withTimestamper: LegallyIdentifiableNode?, forcedID: Int = -1, - factory: ((Path, NodeConfiguration, network: MockNetwork, LegallyIdentifiableNode?) -> MockNode)? = null): MockNode { + fun createNode(withTimestamper: LegallyIdentifiableNode?, forcedID: Int = -1, nodeFactory: Factory = defaultFactory): MockNode { val newNode = forcedID == -1 val id = if (newNode) counter++ else forcedID @@ -94,8 +107,7 @@ class MockNetwork(private val threadPerNode: Boolean = false) { override val exportJMXto: String = "" override val nearestCity: String = "Atlantis" } - val fac = factory ?: { p, n, n2, l -> MockNode(p, n, n2, l, id) } - val node = fac(path, config, this, withTimestamper).start() + val node = nodeFactory.create(path, config, this, withTimestamper).start() _nodes.add(node) return node } @@ -117,8 +129,8 @@ class MockNetwork(private val threadPerNode: Boolean = false) { /** * Sets up a two node network in which the first node runs a timestamping service and the other doesn't. */ - fun createTwoNodes(factory: ((Path, NodeConfiguration, network: MockNetwork, LegallyIdentifiableNode?) -> MockNode)? = null): Pair<MockNode, MockNode> { + fun createTwoNodes(nodeFactory: Factory = defaultFactory): Pair<MockNode, MockNode> { require(nodes.isEmpty()) - return Pair(createNode(null, -1, factory), createNode(nodes[0].legallyIdentifableAddress, -1, factory)) + return Pair(createNode(null, -1, nodeFactory), createNode(nodes[0].legallyIdentifableAddress, -1, nodeFactory)) } } \ No newline at end of file diff --git a/src/test/kotlin/core/messaging/AttachmentTests.kt b/src/test/kotlin/core/messaging/AttachmentTests.kt index e63c9ecf74..66fbaab6ff 100644 --- a/src/test/kotlin/core/messaging/AttachmentTests.kt +++ b/src/test/kotlin/core/messaging/AttachmentTests.kt @@ -8,22 +8,25 @@ package core.messaging -import protocols.FetchAttachmentsProtocol -import protocols.FetchDataProtocol import core.Attachment import core.crypto.SecureHash import core.crypto.sha256 -import core.testing.MockNetwork +import core.node.NodeConfiguration +import core.node.services.LegallyIdentifiableNode import core.node.services.NodeAttachmentService import core.serialization.OpaqueBytes +import core.testing.MockNetwork import core.testutils.rootCauseExceptions import core.utilities.BriefLogFormatter import org.junit.Before import org.junit.Test +import protocols.FetchAttachmentsProtocol +import protocols.FetchDataProtocol import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.nio.ByteBuffer import java.nio.file.Files +import java.nio.file.Path import java.nio.file.StandardOpenOption import java.util.jar.JarOutputStream import java.util.zip.ZipEntry @@ -90,15 +93,17 @@ class AttachmentTests { @Test fun maliciousResponse() { // Make a node that doesn't do sanity checking at load time. - val n0 = network.createNode(null) { path, config, mock, ts -> - object : MockNetwork.MockNode(path, config, mock, ts) { - override fun start(): MockNetwork.MockNode { - super.start() - (storage.attachments as NodeAttachmentService).checkAttachmentsOnLoad = false - return this + val n0 = network.createNode(null, nodeFactory = object : MockNetwork.Factory { + override fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, timestamperAddr: LegallyIdentifiableNode?): MockNetwork.MockNode { + return object : MockNetwork.MockNode(dir, config, network, timestamperAddr) { + override fun start(): MockNetwork.MockNode { + super.start() + (storage.attachments as NodeAttachmentService).checkAttachmentsOnLoad = false + return this + } } } - } + }) val n1 = network.createNode(n0.legallyIdentifableAddress) // Insert an attachment into node zero's store directly. diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index cba57e831d..5616f14ca1 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -12,6 +12,7 @@ import contracts.Cash import contracts.CommercialPaper import core.* import core.crypto.SecureHash +import core.node.NodeConfiguration import core.node.services.* import core.testing.InMemoryMessagingNetwork import core.testing.MockNetwork @@ -163,16 +164,18 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // that Bob was waiting on before the reboot occurred. - bobNode = net.createNode(timestamperAddr, bobAddr.id) { path, nodeConfiguration, net, timestamper -> - object : MockNetwork.MockNode(path, nodeConfiguration, net, timestamper, bobAddr.id) { - override fun initialiseStorageService(dir: Path): StorageService { - val ss = super.initialiseStorageService(dir) - val smMap = ss.stateMachines - smMap.putAll(savedCheckpoints) - return ss + bobNode = net.createNode(timestamperAddr, bobAddr.id, object : MockNetwork.Factory { + override fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, timestamperAddr: LegallyIdentifiableNode?): MockNetwork.MockNode { + return object : MockNetwork.MockNode(dir, config, net, timestamperAddr, bobAddr.id) { + override fun initialiseStorageService(dir: Path): StorageService { + val ss = super.initialiseStorageService(dir) + val smMap = ss.stateMachines + smMap.putAll(savedCheckpoints) + return ss + } } } - } + }) // Find the future representing the result of this state machine again. var bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second @@ -192,16 +195,18 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { // of gets and puts. private fun makeNodeWithTracking(name: String): MockNetwork.MockNode { // Create a node in the mock network ... - return net.createNode(null) { path, config, net, tsNode -> - object : MockNetwork.MockNode(path, config, net, tsNode) { - // That constructs the storage service object in a customised way ... - override fun constructStorageService(attachments: NodeAttachmentService, keypair: KeyPair, identity: Party, - contractFactory: ContractFactory): StorageServiceImpl { - // To use RecordingMaps instead of ordinary HashMaps. - return StorageServiceImpl(attachments, contractFactory, keypair, identity, { tableName -> name }) + return net.createNode(null, nodeFactory = object : MockNetwork.Factory { + override fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, timestamperAddr: LegallyIdentifiableNode?): MockNetwork.MockNode { + return object : MockNetwork.MockNode(dir, config, network, timestamperAddr) { + // That constructs the storage service object in a customised way ... + override fun constructStorageService(attachments: NodeAttachmentService, keypair: KeyPair, identity: Party, + contractFactory: ContractFactory): StorageServiceImpl { + // To use RecordingMaps instead of ordinary HashMaps. + return StorageServiceImpl(attachments, contractFactory, keypair, identity, { tableName -> name }) + } } } - } + }) } @Test