diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/CordaRPCClientImpl.kt b/node/src/main/kotlin/net/corda/node/services/messaging/CordaRPCClientImpl.kt index 6d614d15e8..fc8017f15a 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/CordaRPCClientImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/CordaRPCClientImpl.kt @@ -24,11 +24,13 @@ import org.apache.activemq.artemis.api.core.client.ClientSession import rx.Observable import rx.subjects.PublishSubject import java.io.Closeable +import java.lang.ref.WeakReference import java.lang.reflect.InvocationHandler import java.lang.reflect.Method import java.lang.reflect.Proxy import java.time.Duration import java.util.* +import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.locks.ReentrantLock import javax.annotation.concurrent.GuardedBy import javax.annotation.concurrent.ThreadSafe @@ -107,8 +109,11 @@ class CordaRPCClientImpl(private val session: ClientSession, // do this. private fun ClientMessage.deserialize(kryo: Kryo): T = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }.deserialize(kryo) + // We by default use a weak reference so GC can happen, otherwise they persist for the life of the client. @GuardedBy("sessionLock") - private val addressToQueueObservables = CacheBuilder.newBuilder().build() + private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build() + // This is used to hold a reference counted hard reference when we know there are subscribers. + private val hardReferencesToQueuedObservables = mutableSetOf() private var producer: ClientProducer? = null @@ -118,8 +123,8 @@ class CordaRPCClientImpl(private val session: ClientSession, override fun read(kryo: Kryo, input: Input, type: Class>): Observable { val handle = input.readInt(true) val ob = sessionLock.withLock { - addressToQueueObservables.getIfPresent(qName) ?: QueuedObservable(qName, rpcName, rpcLocation, this).apply { - addressToQueueObservables.put(qName, this) + addressToQueuedObservables.getIfPresent(qName) ?: QueuedObservable(qName, rpcName, rpcLocation, this).apply { + addressToQueuedObservables.put(qName, this) } } val result = ob.getForHandle(handle) @@ -281,7 +286,39 @@ class CordaRPCClientImpl(private val session: ClientSession, // This could be made more efficient by using a specialised IntMap private val observables = HashMap>() - private var consumer: ClientConsumer? = sessionLock.withLock { session.createConsumer(qName) }.setMessageHandler { deliver(it) } + private var consumer: ClientConsumer? = null + + private val referenceCount = AtomicInteger(0) + + // We have to create a weak reference, otherwise we cannot be GC'd. + init { + val weakThis = WeakReference(this) + consumer = sessionLock.withLock { session.createConsumer(qName) }.setMessageHandler { weakThis.get()?.deliver(it) } + } + + /** + * We have to reference count subscriptions to the returned [Observable]s to prevent early GC because we are + * weak referenced. + * + * Derived [Observables] (e.g. filtered etc) hold a strong reference to the original, but for example, if + * the pattern as follows is used, the original passes out of scope and the direction of reference is from the + * original to the [Observer]. We use the reference counting to allow for this pattern. + * + * val observationsSubject = PublishSubject.create() + * originalObservable.subscribe(observationsSubject) + * return observationsSubject + */ + private fun refCountUp() { + if(referenceCount.andIncrement == 0) { + hardReferencesToQueuedObservables.add(this) + } + } + + private fun refCountDown() { + if(referenceCount.decrementAndGet() == 0) { + hardReferencesToQueuedObservables.remove(this) + } + } @Synchronized fun getForHandle(handle: Int): Observable { @@ -296,8 +333,11 @@ class CordaRPCClientImpl(private val session: ClientSession, * * The buffer -> dematerialize order ensures that the Observable may not unsubscribe until the caller * subscribes, which must be after full deserialisation and registering of all top level Observables. + * + * In addition, when subscribe and unsubscribe is called on the [Observable] returned here, we + * reference count a hard reference to this [QueuedObservable] to prevent premature GC. */ - rootShared.filter { it.forHandle == handle }.map { it.what }.bufferUntilSubscribed().dematerialize().share() + rootShared.filter { it.forHandle == handle }.map { it.what }.bufferUntilSubscribed().dematerialize().doOnSubscribe { refCountUp() }.doOnUnsubscribe { refCountDown() }.share() } } @@ -325,9 +365,9 @@ class CordaRPCClientImpl(private val session: ClientSession, fun finalize() { val c = synchronized(this) { consumer } if (c != null) { - rpcLog.warn("A hot observable returned from an RPC ($rpcName) was never subscribed to or explicitly closed. " + + rpcLog.warn("A hot observable returned from an RPC ($rpcName) was never subscribed to. " + "This wastes server-side resources because it was queueing observations for retrieval. " + - "It is being closed now, but please adjust your code to cast the observable to AutoCloseable and then close it explicitly.", rpcLocation) + "It is being closed now, but please adjust your code to subscribe and unsubscribe from the observable to close it explicitly.", rpcLocation) c.close() } }