diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt index 248627a3fb..5ac7fc525f 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt @@ -49,6 +49,22 @@ class ProtocolLogicRefFactory(private val protocolWhitelist: Map, attachments: List = emptyList()): ProtocolLogicRef { + val context = AppContext(attachments) + validateProtocolClassName(protocolLogicClassName, context) + for(arg in args.values.filterNotNull()) { + validateArgClassName(protocolLogicClassName, arg.javaClass.name, context) + } + val clazz = Class.forName(protocolLogicClassName) + require(ProtocolLogic::class.java.isAssignableFrom(clazz)) { "$protocolLogicClassName is not a ProtocolLogic" } + @Suppress("UNCHECKED_CAST") + val logic = clazz as Class>> + return createKotlin(logic, args) + } + /** * Create a [ProtocolLogicRef] for the Kotlin primary constructor or Java constructor and the given args. */ diff --git a/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt b/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt index 578673d2c1..ebd134a56a 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt @@ -5,7 +5,6 @@ import com.r3corda.core.contracts.* import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.services.linearHeadsOfType -import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.serialization.SerializedBytes import com.r3corda.node.api.* import java.time.LocalDateTime @@ -67,15 +66,9 @@ class APIServerImpl(val node: AbstractNode) : APIServer { private fun invokeProtocolAsync(type: ProtocolRef, args: Map): ListenableFuture { if (type is ProtocolClassRef) { - val clazz = Class.forName(type.className) - if (ProtocolLogic::class.java.isAssignableFrom(clazz)) { - @Suppress("UNCHECKED_CAST") - val logic = clazz as Class>> - val protocolLogicRef = node.services.protocolLogicRefFactory.createKotlin(logic, args) - val protocolInstance = node.services.protocolLogicRefFactory.toProtocolLogic(protocolLogicRef) - return node.services.startProtocol(clazz.name, protocolInstance) - } - throw UnsupportedOperationException("Could not find matching protocol and constructor for: $type $args") + val protocolLogicRef = node.services.protocolLogicRefFactory.createKotlin(type.className, args) + val protocolInstance = node.services.protocolLogicRefFactory.toProtocolLogic(protocolLogicRef) + return node.services.startProtocol(type.className, protocolInstance) } else { throw UnsupportedOperationException("Unsupported ProtocolRef type: $type") }