diff --git a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt index 82c825c052..0108e91dee 100644 --- a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt +++ b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt @@ -48,7 +48,7 @@ class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic() { /** * Allows to register a flow of type [R] against an initiating flow of type [I]. */ -inline fun , reified R : FlowLogic<*>> StartedNode.registerInitiatedFlow(initiatingFlowType: KClass, crossinline construct: (session: FlowSession) -> R) { +inline fun , reified R : FlowLogic<*>> StartedNode<*>.registerInitiatedFlow(initiatingFlowType: KClass, crossinline construct: (session: FlowSession) -> R) { internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true) } diff --git a/node/src/integration-test/kotlin/net/corda/MessageState.kt b/node/src/integration-test/kotlin/net/corda/MessageState.kt new file mode 100644 index 0000000000..fcbe58417c --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/MessageState.kt @@ -0,0 +1,70 @@ +package net.corda + +import net.corda.core.contracts.* +import net.corda.core.identity.AbstractParty +import net.corda.core.identity.Party +import net.corda.core.schemas.MappedSchema +import net.corda.core.schemas.PersistentState +import net.corda.core.schemas.QueryableState +import net.corda.core.serialization.CordaSerializable +import net.corda.core.transactions.LedgerTransaction +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Table + +@CordaSerializable +data class Message(val value: String) + +data class MessageState(val message: Message, val by: Party, override val linearId: UniqueIdentifier = UniqueIdentifier()) : LinearState, QueryableState { + override val participants: List = listOf(by) + + override fun generateMappedObject(schema: MappedSchema): PersistentState { + return when (schema) { + is MessageSchemaV1 -> MessageSchemaV1.PersistentMessage( + by = by.name.toString(), + value = message.value + ) + else -> throw IllegalArgumentException("Unrecognised schema $schema") + } + } + + override fun supportedSchemas(): Iterable = listOf(MessageSchemaV1) +} + +object MessageSchema +object MessageSchemaV1 : MappedSchema( + schemaFamily = MessageSchema.javaClass, + version = 1, + mappedTypes = listOf(PersistentMessage::class.java)) { + + @Entity + @Table(name = "messages") + class PersistentMessage( + @Column(name = "by") + var by: String, + + @Column(name = "value") + var value: String + ) : PersistentState() +} + +const val MESSAGE_CONTRACT_PROGRAM_ID = "net.corda.MessageContract" + +open class MessageContract : Contract { + override fun verify(tx: LedgerTransaction) { + val command = tx.commands.requireSingleCommand() + requireThat { + // Generic constraints around the IOU transaction. + "No inputs should be consumed when sending a message." using (tx.inputs.isEmpty()) + "Only one output state should be created." using (tx.outputs.size == 1) + val out = tx.outputsOfType().single() + "Message sender must sign." using (command.signers.containsAll(out.participants.map { it.owningKey })) + + "Message value must not be empty." using (out.message.value.isNotBlank()) + } + } + + interface Commands : CommandData { + class Send : Commands + } +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/RpcInfo.kt b/node/src/integration-test/kotlin/net/corda/RpcInfo.kt new file mode 100644 index 0000000000..c826abb1c4 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/RpcInfo.kt @@ -0,0 +1,7 @@ +package net.corda + +import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.NetworkHostAndPort + +@CordaSerializable +data class RpcInfo(val address: NetworkHostAndPort, val username: String, val password: String) \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/test/node/NodeStatePersistenceTests.kt b/node/src/integration-test/kotlin/net/corda/node/NodeStatePersistenceTests.kt similarity index 73% rename from node/src/integration-test/kotlin/net/corda/test/node/NodeStatePersistenceTests.kt rename to node/src/integration-test/kotlin/net/corda/node/NodeStatePersistenceTests.kt index ef0bc88694..b24f1affd2 100644 --- a/node/src/integration-test/kotlin/net/corda/test/node/NodeStatePersistenceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/NodeStatePersistenceTests.kt @@ -8,43 +8,42 @@ * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. */ -package net.corda.test.node +package net.corda.node import co.paralleluniverse.fibers.Suspendable +import net.corda.MESSAGE_CONTRACT_PROGRAM_ID +import net.corda.Message +import net.corda.MessageContract +import net.corda.MessageState import net.corda.client.rpc.CordaRPCClient -import net.corda.core.contracts.* +import net.corda.core.contracts.Command +import net.corda.core.contracts.StateAndContract +import net.corda.core.contracts.StateAndRef import net.corda.core.flows.FinalityFlow import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC -import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.internal.packageName import net.corda.core.messaging.startFlow -import net.corda.core.schemas.MappedSchema -import net.corda.core.schemas.PersistentState -import net.corda.core.schemas.QueryableState -import net.corda.core.serialization.CordaSerializable -import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.getOrThrow import net.corda.node.services.Permissions.Companion.invokeRpc import net.corda.node.services.Permissions.Companion.startFlow +import net.corda.testing.core.singleIdentity import net.corda.testing.core.* import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.driver import net.corda.testing.driver.internal.RandomFree +import net.corda.testing.node.User import net.corda.testing.internal.IntegrationTest import net.corda.testing.internal.IntegrationTestSchemas import net.corda.testing.internal.toDatabaseSchemaName -import net.corda.testing.node.User import org.junit.Assume.assumeFalse import org.junit.ClassRule import org.junit.Test import java.lang.management.ManagementFactory -import javax.persistence.Column -import javax.persistence.Entity -import javax.persistence.Table import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -63,7 +62,7 @@ class NodeStatePersistenceTests : IntegrationTest() { val user = User("mark", "dadada", setOf(startFlow(), invokeRpc("vaultQuery"))) val message = Message("Hello world!") - val stateAndRef: StateAndRef? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(), portAllocation = RandomFree)) { + val stateAndRef: StateAndRef? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(), portAllocation = RandomFree, extraCordappPackagesToScan = listOf(MessageState::class.packageName))) { val nodeName = { val nodeHandle = startNode(rpcUsers = listOf(user)).getOrThrow() val nodeName = nodeHandle.nodeInfo.singleIdentity().name @@ -97,7 +96,7 @@ class NodeStatePersistenceTests : IntegrationTest() { val user = User("mark", "dadada", setOf(startFlow(), invokeRpc("vaultQuery"))) val message = Message("Hello world!") - val stateAndRef: StateAndRef? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(), portAllocation = RandomFree)) { + val stateAndRef: StateAndRef? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(), portAllocation = RandomFree, extraCordappPackagesToScan = listOf(MessageState::class.packageName))) { val nodeName = { val nodeHandle = startNode(rpcUsers = listOf(user)).getOrThrow() val nodeName = nodeHandle.nodeInfo.singleIdentity().name @@ -129,65 +128,6 @@ fun isQuasarAgentSpecified(): Boolean { return jvmArgs.any { it.startsWith("-javaagent:") && it.endsWith("quasar.jar") } } -@CordaSerializable -data class Message(val value: String) - -data class MessageState(val message: Message, val by: Party, override val linearId: UniqueIdentifier = UniqueIdentifier()) : LinearState, QueryableState { - override val participants: List = listOf(by) - - override fun generateMappedObject(schema: MappedSchema): PersistentState { - return when (schema) { - is MessageSchemaV1 -> MessageSchemaV1.PersistentMessage( - by = by.name.toString(), - value = message.value - ) - else -> throw IllegalArgumentException("Unrecognised schema $schema") - } - } - - override fun supportedSchemas(): Iterable = listOf(MessageSchemaV1) -} - -object MessageSchema -object MessageSchemaV1 : MappedSchema( - schemaFamily = MessageSchema.javaClass, - version = 1, - mappedTypes = listOf(PersistentMessage::class.java)) { - - override val migrationResource = "message-schema.changelog-init" - - @Entity - @Table(name = "messages") - class PersistentMessage( - @Column(name = "message_by") - var by: String, - - @Column(name = "message_value") - var value: String - ) : PersistentState() -} - -const val MESSAGE_CONTRACT_PROGRAM_ID = "net.corda.test.node.MessageContract" - -open class MessageContract : Contract { - override fun verify(tx: LedgerTransaction) { - val command = tx.commands.requireSingleCommand() - requireThat { - // Generic constraints around the IOU transaction. - "No inputs should be consumed when sending a message." using (tx.inputs.isEmpty()) - "Only one output state should be created." using (tx.outputs.size == 1) - val out = tx.outputsOfType().single() - "Message sender must sign." using (command.signers.containsAll(out.participants.map { it.owningKey })) - - "Message value must not be empty." using (out.message.value.isNotBlank()) - } - } - - interface Commands : CommandData { - class Send : Commands - } -} - @StartableByRPC class SendMessageFlow(private val message: Message, private val notary: Party) : FlowLogic() { companion object { diff --git a/node/src/integration-test/kotlin/net/corda/node/modes/draining/FlowsDrainingModeContentionTest.kt b/node/src/integration-test/kotlin/net/corda/node/modes/draining/FlowsDrainingModeContentionTest.kt new file mode 100644 index 0000000000..4716374691 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/modes/draining/FlowsDrainingModeContentionTest.kt @@ -0,0 +1,116 @@ +package net.corda.node.modes.draining + +import co.paralleluniverse.fibers.Suspendable +import net.corda.MESSAGE_CONTRACT_PROGRAM_ID +import net.corda.Message +import net.corda.MessageContract +import net.corda.MessageState +import net.corda.core.contracts.Command +import net.corda.core.contracts.StateAndContract +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.internal.packageName +import net.corda.core.messaging.startFlow +import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.RpcInfo +import net.corda.client.rpc.CordaRPCClient +import net.corda.node.services.Permissions.Companion.all +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.PortAllocation +import net.corda.testing.driver.driver +import net.corda.testing.internal.IntegrationTest +import net.corda.testing.node.User +import org.assertj.core.api.Assertions.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledExecutorService + +class FlowsDrainingModeContentionTest : IntegrationTest() { + + private val portAllocation = PortAllocation.Incremental(10000) + private val user = User("mark", "dadada", setOf(all())) + private val users = listOf(user) + + private var executor: ScheduledExecutorService? = null + + @Before + fun setup() { + executor = Executors.newSingleThreadScheduledExecutor() + } + + @After + fun cleanUp() { + executor!!.shutdown() + } + + @Test + fun `draining mode does not deadlock with acks between 2 nodes`() { + + val message = "Ground control to Major Tom" + + driver(DriverParameters(isDebug = true, startNodesInProcess = true, portAllocation = portAllocation, extraCordappPackagesToScan = listOf(MessageState::class.packageName))) { + + val nodeA = startNode(rpcUsers = users).getOrThrow() + val nodeB = startNode(rpcUsers = users).getOrThrow() + defaultNotaryNode.getOrThrow() + + val nodeARpcInfo = RpcInfo(nodeA.rpcAddress, user.username, user.password) + val flow = nodeA.rpc.startFlow(::ProposeTransactionAndWaitForCommit, message, nodeARpcInfo, nodeB.nodeInfo.singleIdentity(), defaultNotaryIdentity) + val committedTx = flow.returnValue.getOrThrow() + + committedTx.inputs + committedTx.tx.outputs + assertThat(committedTx.tx.outputsOfType().single().message.value).isEqualTo(message) + } + } +} + +@StartableByRPC +@InitiatingFlow +class ProposeTransactionAndWaitForCommit(private val data: String, private val myRpcInfo: RpcInfo, private val counterParty: Party, private val notary: Party) : FlowLogic() { + + @Suspendable + override fun call(): SignedTransaction { + + val session = initiateFlow(counterParty) + val messageState = MessageState(message = Message(data), by = ourIdentity) + val command = Command(MessageContract.Commands.Send(), messageState.participants.map { it.owningKey }) + val transaction = TransactionBuilder(notary) + transaction.withItems(StateAndContract(messageState, MESSAGE_CONTRACT_PROGRAM_ID), command) + val signedTx = serviceHub.signInitialTransaction(transaction) + + subFlow(SendTransactionFlow(session, signedTx)) + session.send(myRpcInfo) + + return waitForLedgerCommit(signedTx.id) + } +} + +@InitiatedBy(ProposeTransactionAndWaitForCommit::class) +class SignTransactionTriggerDrainingModeAndFinality(private val session: FlowSession) : FlowLogic() { + + @Suspendable + override fun call() { + + val tx = subFlow(ReceiveTransactionFlow(session)) + val signedTx = serviceHub.addSignature(tx) + val initiatingRpcInfo = session.receive().unwrap { it } + + triggerDrainingModeForInitiatingNode(initiatingRpcInfo) + + subFlow(FinalityFlow(signedTx, setOf(session.counterparty))) + } + + private fun triggerDrainingModeForInitiatingNode(initiatingRpcInfo: RpcInfo) { + + CordaRPCClient(initiatingRpcInfo.address).start(initiatingRpcInfo.username, initiatingRpcInfo.password).use { + it.proxy.setFlowsDrainingModeEnabled(true) + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index a22a03bd1d..daede37ee4 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -13,6 +13,8 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable import com.codahale.metrics.* +import net.corda.core.context.InvocationOrigin +import net.corda.core.identity.Party import net.corda.core.internal.concurrent.thenMatch import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializedBytes @@ -72,7 +74,7 @@ class ActionExecutorImpl( is Action.ScheduleEvent -> executeScheduleEvent(fiber, action) is Action.SleepUntil -> executeSleepUntil(action) is Action.RemoveCheckpoint -> executeRemoveCheckpoint(action) - is Action.SendInitial -> executeSendInitial(action) + is Action.SendInitial -> executeSendInitial(action, fiber.mightDeadlockDrainingTarget(action.party)) is Action.SendExisting -> executeSendExisting(action) is Action.AddSessionBinding -> executeAddSessionBinding(action) is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action) @@ -84,6 +86,12 @@ class ActionExecutorImpl( } } + private fun FlowFiber.mightDeadlockDrainingTarget(target: Party): Boolean { + // This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator. + // It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case. + return invocationContext().origin.let { it is InvocationOrigin.Peer && it.party == target.name } + } + @Suspendable private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) { services.validatedTransactions.trackTransaction(action.hash).thenMatch( @@ -166,8 +174,8 @@ class ActionExecutorImpl( } @Suspendable - private fun executeSendInitial(action: Action.SendInitial) { - flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId) + private fun executeSendInitial(action: Action.SendInitial, omitDrainingModeHeaders: Boolean) { + flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId, omitDrainingModeHeaders) } @Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt index dbf4b6b8fe..596a8ae653 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt @@ -11,6 +11,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable +import net.corda.core.context.InvocationContext import net.corda.core.flows.StateMachineRunId import net.corda.node.services.statemachine.transitions.StateMachine @@ -25,4 +26,6 @@ interface FlowFiber { fun scheduleEvent(event: Event) fun snapshot(): StateMachineState + + fun invocationContext(): InvocationContext } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt index 496252a5f8..4c5c221333 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -33,7 +33,7 @@ interface FlowMessaging { * listen on the send acknowledgement. */ @Suspendable - fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) + fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, omitDrainingModeHeaders: Boolean = false) /** * Start the messaging using the [onMessage] message handler. @@ -59,9 +59,9 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { } @Suspendable - override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) { + override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, omitDrainingModeHeaders: Boolean) { log.trace { "Sending message $deduplicationId $message to party $party" } - val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders()) + val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(omitDrainingModeHeaders)) val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party") val address = serviceHub.networkService.getAddressOfParty(partyInfo) val sequenceKey = when (message) { @@ -71,10 +71,10 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey) } - private fun SessionMessage.additionalHeaders(): Map { - return when (this) { - is InitialSessionMessage -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE) - else -> emptyMap() + private fun SessionMessage.additionalHeaders(omitDrainingModeHeaders: Boolean): Map { + return when { + this !is InitialSessionMessage || omitDrainingModeHeaders -> emptyMap() + else -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index afd2718d2f..26c26b5c91 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -64,6 +64,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, override val serviceHub get() = getTransientField(TransientValues::serviceHub) + override fun invocationContext(): InvocationContext = snapshot().flowLogic.stateMachine.context + data class TransientValues( val eventQueue: Channel, val resultFuture: CordaFuture,