diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index 07fd3c57b1..ed82b92a28 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -77,6 +77,7 @@ internal class CordaRPCOpsImpl( paging: PageSpecification, sorting: Sort, contractStateType: Class): Vault.Page { + contractStateType.checkIsA() return database.transaction { services.vaultService._queryBy(criteria, paging, sorting, contractStateType) } @@ -87,6 +88,7 @@ internal class CordaRPCOpsImpl( paging: PageSpecification, sorting: Sort, contractStateType: Class): DataFeed, Vault.Update> { + contractStateType.checkIsA() return database.transaction { services.vaultService._trackBy(criteria, paging, sorting, contractStateType) } @@ -315,14 +317,25 @@ internal class CordaRPCOpsImpl( } private fun InvocationContext.toFlowInitiator(): FlowInitiator { - val principal = origin.principal().name return when (origin) { is InvocationOrigin.RPC -> FlowInitiator.RPC(principal) - is InvocationOrigin.Peer -> services.identityService.wellKnownPartyFromX500Name((origin as InvocationOrigin.Peer).party)?.let { FlowInitiator.Peer(it) } ?: throw IllegalStateException("Unknown peer with name ${(origin as InvocationOrigin.Peer).party}.") + is InvocationOrigin.Peer -> { + val wellKnownParty = services.identityService.wellKnownPartyFromX500Name((origin as InvocationOrigin.Peer).party) + wellKnownParty?.let { FlowInitiator.Peer(it) } + ?: throw IllegalStateException("Unknown peer with name ${(origin as InvocationOrigin.Peer).party}.") + } is InvocationOrigin.Service -> FlowInitiator.Service(principal) InvocationOrigin.Shell -> FlowInitiator.Shell is InvocationOrigin.Scheduled -> FlowInitiator.Scheduled((origin as InvocationOrigin.Scheduled).scheduledState) } } + + /** + * RPC can be invoked from the shell where the type parameter of any [Class] parameter is lost, so we must + * explicitly check that the provided [Class] is the one we want. + */ + private inline fun Class<*>.checkIsA() { + require(TARGET::class.java.isAssignableFrom(this)) { "$name is not a ${TARGET::class.java.name}" } + } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt b/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt index 670226814f..1297bd97e6 100644 --- a/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt +++ b/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt @@ -14,11 +14,8 @@ import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party -import net.corda.core.messaging.CordaRPCOps -import net.corda.core.messaging.StateMachineUpdate -import net.corda.core.messaging.startFlow -import net.corda.core.messaging.vaultQueryBy -import net.corda.core.messaging.vaultTrackBy +import net.corda.core.internal.uncheckedCast +import net.corda.core.messaging.* import net.corda.core.node.services.Vault import net.corda.core.node.services.queryBy import net.corda.core.transactions.SignedTransaction @@ -48,8 +45,7 @@ import net.corda.testing.node.internal.InternalMockNetwork.MockNode import net.corda.testing.node.internal.InternalMockNodeParameters import net.corda.testing.node.testActor import org.apache.commons.io.IOUtils -import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.assertj.core.api.Assertions.* import org.junit.After import org.junit.Assert.assertArrayEquals import org.junit.Before @@ -105,9 +101,12 @@ class CordaRPCOpsImplTest { @Test fun `cash issue accepted`() { - - withPermissions(invokeRpc("vaultTrackBy"), invokeRpc("vaultQueryBy"), invokeRpc(CordaRPCOps::stateMachinesFeed), startFlow()) { - + withPermissions( + invokeRpc("vaultTrackBy"), + invokeRpc("vaultQueryBy"), + invokeRpc(CordaRPCOps::stateMachinesFeed), + startFlow() + ) { aliceNode.database.transaction { stateMachineUpdates = rpc.stateMachinesFeed().updates vaultTrackCash = rpc.vaultTrackBy().updates @@ -158,7 +157,6 @@ class CordaRPCOpsImplTest { @Test fun `issue and move`() { - withPermissions(invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::internalVerifiedTransactionsFeed), invokeRpc("vaultTrackBy"), @@ -268,9 +266,9 @@ class CordaRPCOpsImplTest { withPermissions(invokeRpc(CordaRPCOps::uploadAttachment), invokeRpc(CordaRPCOps::attachmentExists)) { val inputJar1 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar) val inputJar2 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar) - val secureHash1 = rpc.uploadAttachment(inputJar1) + rpc.uploadAttachment(inputJar1) assertThatExceptionOfType(java.nio.file.FileAlreadyExistsException::class.java).isThrownBy { - val secureHash2 = rpc.uploadAttachment(inputJar2) + rpc.uploadAttachment(inputJar2) } } } @@ -301,13 +299,14 @@ class CordaRPCOpsImplTest { @Test fun `kill a stuck flow through RPC`() { - - withPermissions(startFlow(), invokeRpc(CordaRPCOps::killFlow), invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::stateMachinesSnapshot)) { - + withPermissions( + startFlow(), + invokeRpc(CordaRPCOps::killFlow), + invokeRpc(CordaRPCOps::stateMachinesFeed), + invokeRpc(CordaRPCOps::stateMachinesSnapshot) + ) { val flow = rpc.startFlow(::NewJoinerFlow) - val killed = rpc.killFlow(flow.id) - assertThat(killed).isTrue() assertThat(rpc.stateMachinesSnapshot().map { info -> info.id }).doesNotContain(flow.id) } @@ -315,13 +314,14 @@ class CordaRPCOpsImplTest { @Test fun `kill a waiting flow through RPC`() { - - withPermissions(startFlow(), invokeRpc(CordaRPCOps::killFlow), invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::stateMachinesSnapshot)) { - + withPermissions( + startFlow(), + invokeRpc(CordaRPCOps::killFlow), + invokeRpc(CordaRPCOps::stateMachinesFeed), + invokeRpc(CordaRPCOps::stateMachinesSnapshot) + ) { val flow = rpc.startFlow(::HopefulFlow, alice) - val killed = rpc.killFlow(flow.id) - assertThat(killed).isTrue() assertThat(rpc.stateMachinesSnapshot().map { info -> info.id }).doesNotContain(flow.id) } @@ -329,23 +329,26 @@ class CordaRPCOpsImplTest { @Test fun `kill a nonexistent flow through RPC`() { - withPermissions(invokeRpc(CordaRPCOps::killFlow)) { - val nonexistentFlowId = StateMachineRunId.createRandom() - val killed = rpc.killFlow(nonexistentFlowId) - assertThat(killed).isFalse() } } + @Test + fun `non-ContractState class for the contractStateType param in vault queries`() { + val nonContractStateClass: Class = uncheckedCast(Cash::class.java) + withPermissions(invokeRpc("vaultTrack"), invokeRpc("vaultQuery")) { + assertThatThrownBy { rpc.vaultQuery(nonContractStateClass) }.hasMessageContaining(Cash::class.java.name) + assertThatThrownBy { rpc.vaultTrack(nonContractStateClass) }.hasMessageContaining(Cash::class.java.name) + } + } + @StartableByRPC class NewJoinerFlow : FlowLogic() { - @Suspendable override fun call(): String { - logger.info("When can I join you say? Almost there buddy...") Fiber.currentFiber().join() return "You'll never get me!" @@ -354,13 +357,10 @@ class CordaRPCOpsImplTest { @StartableByRPC class HopefulFlow(private val party: Party) : FlowLogic() { - @Suspendable override fun call(): String { - logger.info("Waiting for a miracle...") - val miracle = initiateFlow(party).receive().unwrap { it } - return miracle + return initiateFlow(party).receive().unwrap { it } } } @@ -384,17 +384,15 @@ class CordaRPCOpsImplTest { override fun call(): Void? = null } - private fun withPermissions(vararg permissions: String, action: () -> Unit) { - + private inline fun withPermissions(vararg permissions: String, action: () -> Unit) { val previous = CURRENT_RPC_CONTEXT.get() try { - CURRENT_RPC_CONTEXT.set(previous.copy(authorizer = - buildSubject(previous.principal, permissions.toSet()))) + CURRENT_RPC_CONTEXT.set(previous.copy(authorizer = buildSubject(previous.principal, permissions.toSet()))) action.invoke() } finally { CURRENT_RPC_CONTEXT.set(previous) } } - private fun withoutAnyPermissions(action: () -> Unit) = withPermissions(action = action) -} \ No newline at end of file + private inline fun withoutAnyPermissions(action: () -> Unit) = withPermissions(action = action) +}