Explicitly check the contractStateType param of the RPC vault queries is a ContractState class. (#3251)

We lose the compile-time checks of the Class type parameter when invoking from the shell.
This commit is contained in:
Shams Asari 2018-05-29 15:01:55 +01:00 committed by GitHub
parent a359f627d5
commit 0f82e2df7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 41 deletions

View File

@ -77,6 +77,7 @@ internal class CordaRPCOpsImpl(
paging: PageSpecification,
sorting: Sort,
contractStateType: Class<out T>): Vault.Page<T> {
contractStateType.checkIsA<ContractState>()
return database.transaction {
services.vaultService._queryBy(criteria, paging, sorting, contractStateType)
}
@ -87,6 +88,7 @@ internal class CordaRPCOpsImpl(
paging: PageSpecification,
sorting: Sort,
contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> {
contractStateType.checkIsA<ContractState>()
return database.transaction {
services.vaultService._trackBy(criteria, paging, sorting, contractStateType)
}
@ -315,14 +317,25 @@ internal class CordaRPCOpsImpl(
}
private fun InvocationContext.toFlowInitiator(): FlowInitiator {
val principal = origin.principal().name
return when (origin) {
is InvocationOrigin.RPC -> FlowInitiator.RPC(principal)
is InvocationOrigin.Peer -> services.identityService.wellKnownPartyFromX500Name((origin as InvocationOrigin.Peer).party)?.let { FlowInitiator.Peer(it) } ?: throw IllegalStateException("Unknown peer with name ${(origin as InvocationOrigin.Peer).party}.")
is InvocationOrigin.Peer -> {
val wellKnownParty = services.identityService.wellKnownPartyFromX500Name((origin as InvocationOrigin.Peer).party)
wellKnownParty?.let { FlowInitiator.Peer(it) }
?: throw IllegalStateException("Unknown peer with name ${(origin as InvocationOrigin.Peer).party}.")
}
is InvocationOrigin.Service -> FlowInitiator.Service(principal)
InvocationOrigin.Shell -> FlowInitiator.Shell
is InvocationOrigin.Scheduled -> FlowInitiator.Scheduled((origin as InvocationOrigin.Scheduled).scheduledState)
}
}
/**
* RPC can be invoked from the shell where the type parameter of any [Class] parameter is lost, so we must
* explicitly check that the provided [Class] is the one we want.
*/
private inline fun <reified TARGET> Class<*>.checkIsA() {
require(TARGET::class.java.isAssignableFrom(this)) { "$name is not a ${TARGET::class.java.name}" }
}
}

View File

@ -14,11 +14,8 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.messaging.startFlow
import net.corda.core.messaging.vaultQueryBy
import net.corda.core.messaging.vaultTrackBy
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.*
import net.corda.core.node.services.Vault
import net.corda.core.node.services.queryBy
import net.corda.core.transactions.SignedTransaction
@ -48,8 +45,7 @@ import net.corda.testing.node.internal.InternalMockNetwork.MockNode
import net.corda.testing.node.internal.InternalMockNodeParameters
import net.corda.testing.node.testActor
import org.apache.commons.io.IOUtils
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.assertj.core.api.Assertions.*
import org.junit.After
import org.junit.Assert.assertArrayEquals
import org.junit.Before
@ -105,9 +101,12 @@ class CordaRPCOpsImplTest {
@Test
fun `cash issue accepted`() {
withPermissions(invokeRpc("vaultTrackBy"), invokeRpc("vaultQueryBy"), invokeRpc(CordaRPCOps::stateMachinesFeed), startFlow<CashIssueFlow>()) {
withPermissions(
invokeRpc("vaultTrackBy"),
invokeRpc("vaultQueryBy"),
invokeRpc(CordaRPCOps::stateMachinesFeed),
startFlow<CashIssueFlow>()
) {
aliceNode.database.transaction {
stateMachineUpdates = rpc.stateMachinesFeed().updates
vaultTrackCash = rpc.vaultTrackBy<Cash.State>().updates
@ -158,7 +157,6 @@ class CordaRPCOpsImplTest {
@Test
fun `issue and move`() {
withPermissions(invokeRpc(CordaRPCOps::stateMachinesFeed),
invokeRpc(CordaRPCOps::internalVerifiedTransactionsFeed),
invokeRpc("vaultTrackBy"),
@ -268,9 +266,9 @@ class CordaRPCOpsImplTest {
withPermissions(invokeRpc(CordaRPCOps::uploadAttachment), invokeRpc(CordaRPCOps::attachmentExists)) {
val inputJar1 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val inputJar2 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val secureHash1 = rpc.uploadAttachment(inputJar1)
rpc.uploadAttachment(inputJar1)
assertThatExceptionOfType(java.nio.file.FileAlreadyExistsException::class.java).isThrownBy {
val secureHash2 = rpc.uploadAttachment(inputJar2)
rpc.uploadAttachment(inputJar2)
}
}
}
@ -301,13 +299,14 @@ class CordaRPCOpsImplTest {
@Test
fun `kill a stuck flow through RPC`() {
withPermissions(startFlow<NewJoinerFlow>(), invokeRpc(CordaRPCOps::killFlow), invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::stateMachinesSnapshot)) {
withPermissions(
startFlow<NewJoinerFlow>(),
invokeRpc(CordaRPCOps::killFlow),
invokeRpc(CordaRPCOps::stateMachinesFeed),
invokeRpc(CordaRPCOps::stateMachinesSnapshot)
) {
val flow = rpc.startFlow(::NewJoinerFlow)
val killed = rpc.killFlow(flow.id)
assertThat(killed).isTrue()
assertThat(rpc.stateMachinesSnapshot().map { info -> info.id }).doesNotContain(flow.id)
}
@ -315,13 +314,14 @@ class CordaRPCOpsImplTest {
@Test
fun `kill a waiting flow through RPC`() {
withPermissions(startFlow<HopefulFlow>(), invokeRpc(CordaRPCOps::killFlow), invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::stateMachinesSnapshot)) {
withPermissions(
startFlow<HopefulFlow>(),
invokeRpc(CordaRPCOps::killFlow),
invokeRpc(CordaRPCOps::stateMachinesFeed),
invokeRpc(CordaRPCOps::stateMachinesSnapshot)
) {
val flow = rpc.startFlow(::HopefulFlow, alice)
val killed = rpc.killFlow(flow.id)
assertThat(killed).isTrue()
assertThat(rpc.stateMachinesSnapshot().map { info -> info.id }).doesNotContain(flow.id)
}
@ -329,23 +329,26 @@ class CordaRPCOpsImplTest {
@Test
fun `kill a nonexistent flow through RPC`() {
withPermissions(invokeRpc(CordaRPCOps::killFlow)) {
val nonexistentFlowId = StateMachineRunId.createRandom()
val killed = rpc.killFlow(nonexistentFlowId)
assertThat(killed).isFalse()
}
}
@Test
fun `non-ContractState class for the contractStateType param in vault queries`() {
val nonContractStateClass: Class<out ContractState> = uncheckedCast(Cash::class.java)
withPermissions(invokeRpc("vaultTrack"), invokeRpc("vaultQuery")) {
assertThatThrownBy { rpc.vaultQuery(nonContractStateClass) }.hasMessageContaining(Cash::class.java.name)
assertThatThrownBy { rpc.vaultTrack(nonContractStateClass) }.hasMessageContaining(Cash::class.java.name)
}
}
@StartableByRPC
class NewJoinerFlow : FlowLogic<String>() {
@Suspendable
override fun call(): String {
logger.info("When can I join you say? Almost there buddy...")
Fiber.currentFiber().join()
return "You'll never get me!"
@ -354,13 +357,10 @@ class CordaRPCOpsImplTest {
@StartableByRPC
class HopefulFlow(private val party: Party) : FlowLogic<String>() {
@Suspendable
override fun call(): String {
logger.info("Waiting for a miracle...")
val miracle = initiateFlow(party).receive<String>().unwrap { it }
return miracle
return initiateFlow(party).receive<String>().unwrap { it }
}
}
@ -384,17 +384,15 @@ class CordaRPCOpsImplTest {
override fun call(): Void? = null
}
private fun withPermissions(vararg permissions: String, action: () -> Unit) {
private inline fun withPermissions(vararg permissions: String, action: () -> Unit) {
val previous = CURRENT_RPC_CONTEXT.get()
try {
CURRENT_RPC_CONTEXT.set(previous.copy(authorizer =
buildSubject(previous.principal, permissions.toSet())))
CURRENT_RPC_CONTEXT.set(previous.copy(authorizer = buildSubject(previous.principal, permissions.toSet())))
action.invoke()
} finally {
CURRENT_RPC_CONTEXT.set(previous)
}
}
private fun withoutAnyPermissions(action: () -> Unit) = withPermissions(action = action)
}
private inline fun withoutAnyPermissions(action: () -> Unit) = withPermissions(action = action)
}