diff --git a/client/src/main/kotlin/com/r3corda/client/model/ContractStateModel.kt b/client/src/main/kotlin/com/r3corda/client/model/ContractStateModel.kt index a16225c6d1..5ee19705c7 100644 --- a/client/src/main/kotlin/com/r3corda/client/model/ContractStateModel.kt +++ b/client/src/main/kotlin/com/r3corda/client/model/ContractStateModel.kt @@ -1,20 +1,23 @@ package com.r3corda.client.model +import com.r3corda.client.fxutils.foldToObservableList import com.r3corda.contracts.asset.Cash import com.r3corda.core.contracts.ContractState import com.r3corda.core.contracts.StateAndRef import com.r3corda.core.contracts.StateRef -import com.r3corda.client.fxutils.foldToObservableList import com.r3corda.node.services.monitor.ServiceToClientEvent import com.r3corda.node.services.monitor.StateSnapshotMessage import javafx.collections.ObservableList import kotlinx.support.jdk8.collections.removeIf import rx.Observable -class StatesDiff( - val added: Collection>, - val removed: Collection -) +sealed class StatesModification{ + class Diff( + val added: Collection>, + val removed: Collection + ) : StatesModification() + class Reset(val states: Collection>) : StatesModification() +} /** * This model exposes the list of owned contract states. @@ -24,16 +27,45 @@ class ContractStateModel { private val snapshot: Observable by observable(WalletMonitorModel::snapshot) private val outputStates = serviceToClient.ofType(ServiceToClientEvent.OutputState::class.java) - val contractStatesDiff = outputStates.map { StatesDiff(it.produced, it.consumed) } + val contractStatesDiff: Observable> = + outputStates.map { StatesModification.Diff(it.produced, it.consumed) } // We filter the diff first rather than the complete contract state list. - // TODO wire up snapshot once it holds StateAndRefs - val cashStatesDiff = contractStatesDiff.map { - StatesDiff(it.added.filterIsInstance>(), it.removed) - } + val cashStatesModification: Observable> = Observable.merge( + arrayOf( + contractStatesDiff.map { + StatesModification.Diff(it.added.filterCashStateAndRefs(), it.removed) + }, + snapshot.map { + StatesModification.Reset(it.contractStates.filterCashStateAndRefs()) + } + ) + ) val cashStates: ObservableList> = - cashStatesDiff.foldToObservableList(Unit) { statesDiff, _accumulator, observableList -> - observableList.removeIf { it.ref in statesDiff.removed } - observableList.addAll(statesDiff.added) + cashStatesModification.foldToObservableList(Unit) { statesDiff, _accumulator, observableList -> + when (statesDiff) { + is StatesModification.Diff -> { + observableList.removeIf { it.ref in statesDiff.removed } + observableList.addAll(statesDiff.added) + } + is StatesModification.Reset -> { + observableList.setAll(statesDiff.states) + } + } } + + companion object { + private fun Collection>.filterCashStateAndRefs(): List> { + return this.map { stateAndRef -> + @Suppress("UNCHECKED_CAST") + if (stateAndRef.state.data is Cash.State) { + // Kotlin doesn't unify here for some reason + stateAndRef as StateAndRef + } else { + null + } + }.filterNotNull() + } + } + } diff --git a/node/src/main/kotlin/com/r3corda/node/services/monitor/Messages.kt b/node/src/main/kotlin/com/r3corda/node/services/monitor/Messages.kt index 798f319c8f..610770ecdb 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/monitor/Messages.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/monitor/Messages.kt @@ -2,6 +2,7 @@ package com.r3corda.node.services.monitor import com.r3corda.core.contracts.ClientToServiceCommand import com.r3corda.core.contracts.ContractState +import com.r3corda.core.contracts.StateAndRef import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.protocols.DirectRequestMessage @@ -14,6 +15,6 @@ data class DeregisterRequest(override val replyToRecipient: SingleMessageRecipie override val sessionID: Long) : DirectRequestMessage data class DeregisterResponse(val success: Boolean) -data class StateSnapshotMessage(val contractStates: Collection, val protocolStates: Collection) +data class StateSnapshotMessage(val contractStates: Collection>, val protocolStates: Collection) data class ClientToServiceCommandMessage(override val sessionID: Long, override val replyToRecipient: SingleMessageRecipient, val command: ClientToServiceCommand) : DirectRequestMessage diff --git a/node/src/main/kotlin/com/r3corda/node/services/monitor/WalletMonitorService.kt b/node/src/main/kotlin/com/r3corda/node/services/monitor/WalletMonitorService.kt index b339b8bd27..6f4d3cd7a7 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/monitor/WalletMonitorService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/monitor/WalletMonitorService.kt @@ -135,7 +135,7 @@ class WalletMonitorService(services: ServiceHubInternal, val smm: StateMachineMa fun processRegisterRequest(req: RegisterRequest) { try { listeners.add(RegisteredListener(req.replyToRecipient, req.sessionID)) - val stateMessage = StateSnapshotMessage(services.walletService.currentWallet.states.map { it.state.data }.toList(), + val stateMessage = StateSnapshotMessage(services.walletService.currentWallet.states.toList(), smm.allStateMachines.map { it.javaClass.name }) net.send(net.createMessage(STATE_TOPIC, DEFAULT_SESSION_ID, stateMessage.serialize().bits), req.replyToRecipient)