diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt index 248627a3fb..5ac7fc525f 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt @@ -49,6 +49,22 @@ class ProtocolLogicRefFactory(private val protocolWhitelist: Map, attachments: List = emptyList()): ProtocolLogicRef { + val context = AppContext(attachments) + validateProtocolClassName(protocolLogicClassName, context) + for(arg in args.values.filterNotNull()) { + validateArgClassName(protocolLogicClassName, arg.javaClass.name, context) + } + val clazz = Class.forName(protocolLogicClassName) + require(ProtocolLogic::class.java.isAssignableFrom(clazz)) { "$protocolLogicClassName is not a ProtocolLogic" } + @Suppress("UNCHECKED_CAST") + val logic = clazz as Class>> + return createKotlin(logic, args) + } + /** * Create a [ProtocolLogicRef] for the Kotlin primary constructor or Java constructor and the given args. */ diff --git a/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt b/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt index 85cafc6692..ebd134a56a 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/APIServerImpl.kt @@ -5,14 +5,10 @@ import com.r3corda.core.contracts.* import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.services.linearHeadsOfType -import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.serialization.SerializedBytes import com.r3corda.node.api.* import java.time.LocalDateTime -import java.util.* import javax.ws.rs.core.Response -import kotlin.reflect.KParameter -import kotlin.reflect.jvm.javaType class APIServerImpl(val node: AbstractNode) : APIServer { @@ -70,38 +66,9 @@ class APIServerImpl(val node: AbstractNode) : APIServer { private fun invokeProtocolAsync(type: ProtocolRef, args: Map): ListenableFuture { if (type is ProtocolClassRef) { - val clazz = Class.forName(type.className) - if (ProtocolLogic::class.java.isAssignableFrom(clazz)) { - // TODO for security, check annotated as exposed on API? Or have PublicProtocolLogic... etc - nextConstructor@ for (constructor in clazz.kotlin.constructors) { - val params = HashMap() - for (parameter in constructor.parameters) { - if (parameter.isOptional && !args.containsKey(parameter.name)) { - // OK to be missing - } else if (args.containsKey(parameter.name)) { - val value = args[parameter.name] - if (value is Any) { - // TODO consider supporting more complex test here to support coercing numeric/Kotlin types - if (!(parameter.type.javaType as Class<*>).isAssignableFrom(value.javaClass)) { - // Not null and not assignable - break@nextConstructor - } - } else if (!parameter.type.isMarkedNullable) { - // Null and not nullable - break@nextConstructor - } - params[parameter] = value - } else { - break@nextConstructor - } - } - // If we get here then we matched every parameter - val protocol = constructor.callBy(params) as ProtocolLogic<*> - val future = node.smm.add("api-call", protocol) - return future - } - } - throw UnsupportedOperationException("Could not find matching protocol and constructor for: $type $args") + val protocolLogicRef = node.services.protocolLogicRefFactory.createKotlin(type.className, args) + val protocolInstance = node.services.protocolLogicRefFactory.toProtocolLogic(protocolLogicRef) + return node.services.startProtocol(type.className, protocolInstance) } else { throw UnsupportedOperationException("Unsupported ProtocolRef type: $type") }