diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index e70f80d39d..39c8a8c3b5 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -8,9 +8,6 @@ import com.esotericsoftware.kryo.io.Output import com.github.benmanes.caffeine.cache.Cache import com.github.benmanes.caffeine.cache.Caffeine import com.github.benmanes.caffeine.cache.RemovalListener -import com.google.common.collect.HashMultimap -import com.google.common.collect.Multimaps -import com.google.common.collect.SetMultimap import com.google.common.util.concurrent.ThreadFactoryBuilder import net.corda.client.rpc.RPCException import net.corda.core.context.Actor @@ -112,7 +109,7 @@ class RPCServer( /** The observable subscription mapping. */ private val observableMap = createObservableSubscriptionMap() /** A mapping from client addresses to IDs of associated Observables */ - private val clientAddressToObservables = Multimaps.synchronizedSetMultimap(HashMultimap.create()) + private val clientAddressToObservables = ConcurrentHashMap>() /** The scheduled reaper handle. */ private var reaperScheduledFuture: ScheduledFuture<*>? = null @@ -291,8 +288,10 @@ class RPCServer( // Observables may be serialised and thus registered. private fun invalidateClient(clientAddress: SimpleString) { lifeCycle.requireState(State.STARTED) - val observableIds = clientAddressToObservables.removeAll(clientAddress) - observableMap.invalidateAll(observableIds) + val observableIds = clientAddressToObservables.remove(clientAddress) + if (observableIds != null) { + observableMap.invalidateAll(observableIds) + } responseMessageBuffer.remove(clientAddress) } @@ -419,7 +418,7 @@ class RPCServer( */ inner class ObservableContext( val observableMap: ObservableSubscriptionMap, - val clientAddressToObservables: SetMultimap, + val clientAddressToObservables: ConcurrentHashMap>, val deduplicationIdentity: String, val clientAddress: SimpleString ) { @@ -525,11 +524,30 @@ object RpcServerObservableSerializer : Serializer>() { } override fun onCompleted() { + observableContext.clientAddressToObservables.compute(observableContext.clientAddress) { _, observables -> + if (observables != null) { + observables.remove(observableId) + if (observables.isEmpty()) { + null + } else { + observables + } + } else { + null + } + } } } ) ) - observableContext.clientAddressToObservables.put(observableContext.clientAddress, observableId) + observableContext.clientAddressToObservables.compute(observableContext.clientAddress) { _, observables -> + if (observables == null) { + hashSetOf(observableId) + } else { + observables.add(observableId) + observables + } + } observableContext.observableMap.put(observableId, observableWithSubscription) }