From 554b1fa371eb35c9009df83b65afa852602f317b Mon Sep 17 00:00:00 2001 From: Konstantinos Chalkias Date: Wed, 10 Oct 2018 10:35:18 +0100 Subject: [PATCH 1/6] [CORDA-2084] EdDSA, SPHINCS-256 and RSA PKCS#1 are deterministic, no RNG required. (#4051) --- .../kotlin/net/corda/core/crypto/Crypto.kt | 20 +- .../net/corda/core/crypto/CryptoUtilsTest.kt | 222 ++++++++++-------- 2 files changed, 138 insertions(+), 104 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt b/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt index e131cdc7df..9da4417c7d 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt @@ -424,15 +424,19 @@ object Crypto { } require(clearData.isNotEmpty()) { "Signing of an empty array is not permitted!" } val signature = Signature.getInstance(signatureScheme.signatureName, providerMap[signatureScheme.providerName]) - // Note that deterministic signature schemes, such as EdDSA, do not require extra randomness, but we have to - // ensure that non-deterministic algorithms (i.e., ECDSA) use non-blocking SecureRandom implementations (if possible). - // TODO consider updating this when the related BC issue for Sphincs is fixed. - if (signatureScheme != SPHINCS256_SHA256) { - signature.initSign(privateKey, newSecureRandom()) - } else { - // Special handling for Sphincs, due to a BC implementation issue. - // As Sphincs is deterministic, it does not require RNG input anyway. + // Note that deterministic signature schemes, such as EdDSA, original SPHINCS-256 and RSA PKCS#1, do not require + // extra randomness, but we have to ensure that non-deterministic algorithms (i.e., ECDSA) use non-blocking + // SecureRandom implementation. Also, SPHINCS-256 implementation in BouncyCastle 1.60 fails with + // ClassCastException if we invoke initSign with a SecureRandom as an input. + // TODO Although we handle the above issue here, consider updating to BC 1.61+ which provides a fix. + if (signatureScheme == EDDSA_ED25519_SHA512 + || signatureScheme == SPHINCS256_SHA256 + || signatureScheme == RSA_SHA256) { signature.initSign(privateKey) + } else { + // The rest of the algorithms will require a SecureRandom input (i.e., ECDSA or any new algorithm for which + // we don't know if it's deterministic). + signature.initSign(privateKey, newSecureRandom()) } signature.update(clearData) return signature.sign() diff --git a/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt b/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt index a8f156b77a..dbfa620d15 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt @@ -1,6 +1,12 @@ package net.corda.core.crypto import com.google.common.collect.Sets +import net.corda.core.crypto.Crypto.ECDSA_SECP256K1_SHA256 +import net.corda.core.crypto.Crypto.ECDSA_SECP256R1_SHA256 +import net.corda.core.crypto.Crypto.EDDSA_ED25519_SHA512 +import net.corda.core.crypto.Crypto.RSA_SHA256 +import net.corda.core.crypto.Crypto.SPHINCS256_SHA256 +import net.corda.core.utilities.OpaqueBytes import net.i2p.crypto.eddsa.EdDSAKey import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPublicKey @@ -30,17 +36,20 @@ import kotlin.test.* */ class CryptoUtilsTest { - private val testBytes = "Hello World".toByteArray() + companion object { + private val testBytes = "Hello World".toByteArray() + private val test100ZeroBytes = ByteArray(100) + } // key generation test @Test fun `Generate key pairs`() { // testing supported algorithms - val rsaKeyPair = Crypto.generateKeyPair(Crypto.RSA_SHA256) - val ecdsaKKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) - val ecdsaRKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) - val eddsaKeyPair = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) - val sphincsKeyPair = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val rsaKeyPair = Crypto.generateKeyPair(RSA_SHA256) + val ecdsaKKeyPair = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) + val ecdsaRKeyPair = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) + val eddsaKeyPair = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) + val sphincsKeyPair = Crypto.generateKeyPair(SPHINCS256_SHA256) // not null private keys assertNotNull(rsaKeyPair.private) @@ -69,7 +78,7 @@ class CryptoUtilsTest { @Test fun `RSA full process keygen-sign-verify`() { - val keyPair = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val keyPair = Crypto.generateKeyPair(RSA_SHA256) val (privKey, pubKey) = keyPair // test for some data val signedData = Crypto.doSign(privKey, testBytes) @@ -101,8 +110,8 @@ class CryptoUtilsTest { } // test for zero bytes data - val signedDataZeros = Crypto.doSign(privKey, ByteArray(100)) - val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, ByteArray(100)) + val signedDataZeros = Crypto.doSign(privKey, test100ZeroBytes) + val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, test100ZeroBytes) assertTrue(verificationZeros) // test for 1MB of data (I successfully tested it locally for 1GB as well) @@ -124,7 +133,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256k1 full process keygen-sign-verify`() { - val keyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPair = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val (privKey, pubKey) = keyPair // test for some data val signedData = Crypto.doSign(privKey, testBytes) @@ -156,8 +165,8 @@ class CryptoUtilsTest { } // test for zero bytes data - val signedDataZeros = Crypto.doSign(privKey, ByteArray(100)) - val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, ByteArray(100)) + val signedDataZeros = Crypto.doSign(privKey, test100ZeroBytes) + val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, test100ZeroBytes) assertTrue(verificationZeros) // test for 1MB of data (I successfully tested it locally for 1GB as well) @@ -179,7 +188,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256r1 full process keygen-sign-verify`() { - val keyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPair = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) val (privKey, pubKey) = keyPair // test for some data val signedData = Crypto.doSign(privKey, testBytes) @@ -211,8 +220,8 @@ class CryptoUtilsTest { } // test for zero bytes data - val signedDataZeros = Crypto.doSign(privKey, ByteArray(100)) - val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, ByteArray(100)) + val signedDataZeros = Crypto.doSign(privKey, test100ZeroBytes) + val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, test100ZeroBytes) assertTrue(verificationZeros) // test for 1MB of data (I successfully tested it locally for 1GB as well) @@ -234,7 +243,7 @@ class CryptoUtilsTest { @Test fun `EDDSA ed25519 full process keygen-sign-verify`() { - val keyPair = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPair = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) val (privKey, pubKey) = keyPair // test for some data val signedData = Crypto.doSign(privKey, testBytes) @@ -266,8 +275,8 @@ class CryptoUtilsTest { } // test for zero bytes data - val signedDataZeros = Crypto.doSign(privKey, ByteArray(100)) - val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, ByteArray(100)) + val signedDataZeros = Crypto.doSign(privKey, test100ZeroBytes) + val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, test100ZeroBytes) assertTrue(verificationZeros) // test for 1MB of data (I successfully tested it locally for 1GB as well) @@ -289,7 +298,7 @@ class CryptoUtilsTest { @Test fun `SPHINCS-256 full process keygen-sign-verify`() { - val keyPair = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val keyPair = Crypto.generateKeyPair(SPHINCS256_SHA256) val (privKey, pubKey) = keyPair // test for some data val signedData = Crypto.doSign(privKey, testBytes) @@ -321,8 +330,8 @@ class CryptoUtilsTest { } // test for zero bytes data - val signedDataZeros = Crypto.doSign(privKey, ByteArray(100)) - val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, ByteArray(100)) + val signedDataZeros = Crypto.doSign(privKey, test100ZeroBytes) + val verificationZeros = Crypto.doVerify(pubKey, signedDataZeros, test100ZeroBytes) assertTrue(verificationZeros) // test for 1MB of data (I successfully tested it locally for 1GB as well) @@ -354,7 +363,7 @@ class CryptoUtilsTest { @Test fun `RSA encode decode keys - required for serialization`() { // Generate key pair. - val keyPair = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val keyPair = Crypto.generateKeyPair(RSA_SHA256) val (privKey, pubKey) = keyPair // Encode and decode private key. @@ -369,7 +378,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256k1 encode decode keys - required for serialization`() { // Generate key pair. - val keyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPair = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val (privKey, pubKey) = keyPair // Encode and decode private key. @@ -384,7 +393,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256r1 encode decode keys - required for serialization`() { // Generate key pair. - val keyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPair = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) val (privKey, pubKey) = keyPair // Encode and decode private key. @@ -399,7 +408,7 @@ class CryptoUtilsTest { @Test fun `EdDSA encode decode keys - required for serialization`() { // Generate key pair. - val keyPair = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPair = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) val (privKey, pubKey) = keyPair // Encode and decode private key. @@ -414,7 +423,7 @@ class CryptoUtilsTest { @Test fun `SPHINCS-256 encode decode keys - required for serialization`() { // Generate key pair. - val keyPair = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val keyPair = Crypto.generateKeyPair(SPHINCS256_SHA256) val privKey: BCSphincs256PrivateKey = keyPair.private as BCSphincs256PrivateKey val pubKey: BCSphincs256PublicKey = keyPair.public as BCSphincs256PublicKey @@ -443,7 +452,7 @@ class CryptoUtilsTest { @Test fun `RSA scheme finder by key type`() { - val keyPairRSA = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val keyPairRSA = Crypto.generateKeyPair(RSA_SHA256) val (privRSA, pubRSA) = keyPairRSA assertEquals(privRSA.algorithm, "RSA") assertEquals(pubRSA.algorithm, "RSA") @@ -451,7 +460,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256k1 scheme finder by key type`() { - val keyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPair = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val (privKey, pubKey) = keyPair // Encode and decode private key. @@ -466,7 +475,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256r1 scheme finder by key type`() { - val keyPairR1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPairR1 = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) val (privR1, pubR1) = keyPairR1 assertEquals(privR1.algorithm, "ECDSA") assertEquals((privR1 as ECKey).parameters, ECNamedCurveTable.getParameterSpec("secp256r1")) @@ -476,7 +485,7 @@ class CryptoUtilsTest { @Test fun `EdDSA scheme finder by key type`() { - val keyPairEd = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPairEd = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) val (privEd, pubEd) = keyPairEd assertEquals(privEd.algorithm, "EdDSA") @@ -487,7 +496,7 @@ class CryptoUtilsTest { @Test fun `SPHINCS-256 scheme finder by key type`() { - val keyPairSP = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val keyPairSP = Crypto.generateKeyPair(SPHINCS256_SHA256) val (privSP, pubSP) = keyPairSP assertEquals(privSP.algorithm, "SPHINCS-256") assertEquals(pubSP.algorithm, "SPHINCS-256") @@ -495,7 +504,7 @@ class CryptoUtilsTest { @Test fun `Automatic EdDSA key-type detection and decoding`() { - val keyPairEd = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPairEd = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) val (privEd, pubEd) = keyPairEd val encodedPrivEd = privEd.encoded val encodedPubEd = pubEd.encoded @@ -511,7 +520,7 @@ class CryptoUtilsTest { @Test fun `Automatic ECDSA secp256k1 key-type detection and decoding`() { - val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairK1 = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val (privK1, pubK1) = keyPairK1 val encodedPrivK1 = privK1.encoded val encodedPubK1 = pubK1.encoded @@ -527,7 +536,7 @@ class CryptoUtilsTest { @Test fun `Automatic ECDSA secp256r1 key-type detection and decoding`() { - val keyPairR1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPairR1 = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) val (privR1, pubR1) = keyPairR1 val encodedPrivR1 = privR1.encoded val encodedPubR1 = pubR1.encoded @@ -543,7 +552,7 @@ class CryptoUtilsTest { @Test fun `Automatic RSA key-type detection and decoding`() { - val keyPairRSA = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val keyPairRSA = Crypto.generateKeyPair(RSA_SHA256) val (privRSA, pubRSA) = keyPairRSA val encodedPrivRSA = privRSA.encoded val encodedPubRSA = pubRSA.encoded @@ -559,7 +568,7 @@ class CryptoUtilsTest { @Test fun `Automatic SPHINCS-256 key-type detection and decoding`() { - val keyPairSP = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val keyPairSP = Crypto.generateKeyPair(SPHINCS256_SHA256) val (privSP, pubSP) = keyPairSP val encodedPrivSP = privSP.encoded val encodedPubSP = pubSP.encoded @@ -575,12 +584,12 @@ class CryptoUtilsTest { @Test fun `Failure test between K1 and R1 keys`() { - val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairK1 = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val privK1 = keyPairK1.private val encodedPrivK1 = privK1.encoded val decodedPrivK1 = Crypto.decodePrivateKey(encodedPrivK1) - val keyPairR1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPairR1 = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) val privR1 = keyPairR1.private val encodedPrivR1 = privR1.encoded val decodedPrivR1 = Crypto.decodePrivateKey(encodedPrivR1) @@ -590,7 +599,7 @@ class CryptoUtilsTest { @Test fun `Decoding Failure on randomdata as key`() { - val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairK1 = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val privK1 = keyPairK1.private val encodedPrivK1 = privK1.encoded @@ -610,7 +619,7 @@ class CryptoUtilsTest { @Test fun `Decoding Failure on malformed keys`() { - val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairK1 = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val privK1 = keyPairK1.private val encodedPrivK1 = privK1.encoded @@ -630,25 +639,25 @@ class CryptoUtilsTest { @Test fun `Check ECDSA public key on curve`() { - val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairK1 = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val pubK1 = keyPairK1.public as BCECPublicKey - assertTrue(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256K1_SHA256, pubK1)) + assertTrue(Crypto.publicKeyOnCurve(ECDSA_SECP256K1_SHA256, pubK1)) // use R1 curve for check. - assertFalse(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, pubK1)) + assertFalse(Crypto.publicKeyOnCurve(ECDSA_SECP256R1_SHA256, pubK1)) // use ed25519 curve for check. - assertFalse(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, pubK1)) + assertFalse(Crypto.publicKeyOnCurve(EDDSA_ED25519_SHA512, pubK1)) } @Test fun `Check EdDSA public key on curve`() { - val keyPairEdDSA = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPairEdDSA = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) val pubEdDSA = keyPairEdDSA.public - assertTrue(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, pubEdDSA)) + assertTrue(Crypto.publicKeyOnCurve(EDDSA_ED25519_SHA512, pubEdDSA)) // Use R1 curve for check. - assertFalse(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, pubEdDSA)) + assertFalse(Crypto.publicKeyOnCurve(ECDSA_SECP256R1_SHA256, pubEdDSA)) // Check for point at infinity. - val pubKeySpec = EdDSAPublicKeySpec((Crypto.EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec).curve.getZero(GroupElement.Representation.P3), Crypto.EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec) - assertFalse(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, EdDSAPublicKey(pubKeySpec))) + val pubKeySpec = EdDSAPublicKeySpec((EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec).curve.getZero(GroupElement.Representation.P3), EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec) + assertFalse(Crypto.publicKeyOnCurve(EDDSA_ED25519_SHA512, EdDSAPublicKey(pubKeySpec))) } @Test(expected = IllegalArgumentException::class) @@ -658,12 +667,12 @@ class CryptoUtilsTest { val pairSun = keyGen.generateKeyPair() val pubSun = pairSun.public // Should fail as pubSun is not a BCECPublicKey. - Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, pubSun) + Crypto.publicKeyOnCurve(ECDSA_SECP256R1_SHA256, pubSun) } @Test fun `ECDSA secp256R1 deterministic key generation`() { - val (priv, pub) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val (priv, pub) = Crypto.generateKeyPair(ECDSA_SECP256R1_SHA256) val (dpriv, dpub) = Crypto.deriveKeyPair(priv, "seed-1".toByteArray()) // Check scheme. @@ -673,11 +682,11 @@ class CryptoUtilsTest { assertTrue(dpub is BCECPublicKey) assertEquals((dpriv as ECKey).parameters, ECNamedCurveTable.getParameterSpec("secp256r1")) assertEquals((dpub as ECKey).parameters, ECNamedCurveTable.getParameterSpec("secp256r1")) - assertEquals(Crypto.findSignatureScheme(dpriv), Crypto.ECDSA_SECP256R1_SHA256) - assertEquals(Crypto.findSignatureScheme(dpub), Crypto.ECDSA_SECP256R1_SHA256) + assertEquals(Crypto.findSignatureScheme(dpriv), ECDSA_SECP256R1_SHA256) + assertEquals(Crypto.findSignatureScheme(dpub), ECDSA_SECP256R1_SHA256) // Validate public key. - assertTrue(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, dpub)) + assertTrue(Crypto.publicKeyOnCurve(ECDSA_SECP256R1_SHA256, dpub)) // Try to sign/verify. val signedData = Crypto.doSign(dpriv, testBytes) @@ -704,7 +713,7 @@ class CryptoUtilsTest { @Test fun `ECDSA secp256K1 deterministic key generation`() { - val (priv, pub) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val (priv, pub) = Crypto.generateKeyPair(ECDSA_SECP256K1_SHA256) val (dpriv, dpub) = Crypto.deriveKeyPair(priv, "seed-1".toByteArray()) // Check scheme. @@ -714,11 +723,11 @@ class CryptoUtilsTest { assertTrue(dpub is BCECPublicKey) assertEquals((dpriv as ECKey).parameters, ECNamedCurveTable.getParameterSpec("secp256k1")) assertEquals((dpub as ECKey).parameters, ECNamedCurveTable.getParameterSpec("secp256k1")) - assertEquals(Crypto.findSignatureScheme(dpriv), Crypto.ECDSA_SECP256K1_SHA256) - assertEquals(Crypto.findSignatureScheme(dpub), Crypto.ECDSA_SECP256K1_SHA256) + assertEquals(Crypto.findSignatureScheme(dpriv), ECDSA_SECP256K1_SHA256) + assertEquals(Crypto.findSignatureScheme(dpub), ECDSA_SECP256K1_SHA256) // Validate public key. - assertTrue(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256K1_SHA256, dpub)) + assertTrue(Crypto.publicKeyOnCurve(ECDSA_SECP256K1_SHA256, dpub)) // Try to sign/verify. val signedData = Crypto.doSign(dpriv, testBytes) @@ -745,7 +754,7 @@ class CryptoUtilsTest { @Test fun `EdDSA ed25519 deterministic key generation`() { - val (priv, pub) = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val (priv, pub) = Crypto.generateKeyPair(EDDSA_ED25519_SHA512) val (dpriv, dpub) = Crypto.deriveKeyPair(priv, "seed-1".toByteArray()) // Check scheme. @@ -755,11 +764,11 @@ class CryptoUtilsTest { assertTrue(dpub is EdDSAPublicKey) assertEquals((dpriv as EdDSAKey).params, EdDSANamedCurveTable.getByName("ED25519")) assertEquals((dpub as EdDSAKey).params, EdDSANamedCurveTable.getByName("ED25519")) - assertEquals(Crypto.findSignatureScheme(dpriv), Crypto.EDDSA_ED25519_SHA512) - assertEquals(Crypto.findSignatureScheme(dpub), Crypto.EDDSA_ED25519_SHA512) + assertEquals(Crypto.findSignatureScheme(dpriv), EDDSA_ED25519_SHA512) + assertEquals(Crypto.findSignatureScheme(dpub), EDDSA_ED25519_SHA512) // Validate public key. - assertTrue(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, dpub)) + assertTrue(Crypto.publicKeyOnCurve(EDDSA_ED25519_SHA512, dpub)) // Try to sign/verify. val signedData = Crypto.doSign(dpriv, testBytes) @@ -786,110 +795,131 @@ class CryptoUtilsTest { @Test fun `EdDSA ed25519 keyPair from entropy`() { - val keyPairPositive = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("10")) + val keyPairPositive = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("10")) assertEquals("DLBL3iHCp9uRReWhhCGfCsrxZZpfAm9h9GLbfN8ijqXTq", keyPairPositive.public.toStringShort()) - val keyPairNegative = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("-10")) + val keyPairNegative = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("-10")) assertEquals("DLC5HXnYsJAFqmM9hgPj5G8whQ4TpyE9WMBssqCayLBwA2", keyPairNegative.public.toStringShort()) - val keyPairZero = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("0")) + val keyPairZero = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("0")) assertEquals("DL4UVhGh4tqu1G86UVoGNaDDNCMsBtNHzE6BSZuNNJN7W2", keyPairZero.public.toStringShort()) - val keyPairOne = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("1")) + val keyPairOne = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("1")) assertEquals("DL8EZUdHixovcCynKMQzrMWBnXQAcbVDHi6ArPphqwJVzq", keyPairOne.public.toStringShort()) - val keyPairBiggerThan256bits = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("2").pow(258).minus(BigInteger.TEN)) + val keyPairBiggerThan256bits = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("2").pow(258).minus(BigInteger.TEN)) assertEquals("DLB9K1UiBrWonn481z6NzkqoWHjMBXpfDeaet3wiwRNWSU", keyPairBiggerThan256bits.public.toStringShort()) // The underlying implementation uses the first 256 bytes of the entropy. Thus, 2^258-10 and 2^258-50 and 2^514-10 have the same impact. - val keyPairBiggerThan256bitsV2 = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("2").pow(258).minus(BigInteger("50"))) + val keyPairBiggerThan256bitsV2 = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("2").pow(258).minus(BigInteger("50"))) assertEquals("DLB9K1UiBrWonn481z6NzkqoWHjMBXpfDeaet3wiwRNWSU", keyPairBiggerThan256bitsV2.public.toStringShort()) - val keyPairBiggerThan512bits = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("2").pow(514).minus(BigInteger.TEN)) + val keyPairBiggerThan512bits = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("2").pow(514).minus(BigInteger.TEN)) assertEquals("DLB9K1UiBrWonn481z6NzkqoWHjMBXpfDeaet3wiwRNWSU", keyPairBiggerThan512bits.public.toStringShort()) // Try another big number. - val keyPairBiggerThan258bits = Crypto.deriveKeyPairFromEntropy(Crypto.EDDSA_ED25519_SHA512, BigInteger("2").pow(259).plus(BigInteger.ONE)) + val keyPairBiggerThan258bits = Crypto.deriveKeyPairFromEntropy(EDDSA_ED25519_SHA512, BigInteger("2").pow(259).plus(BigInteger.ONE)) assertEquals("DL5tEFVMXMGrzwjfCAW34JjkhsRkPfFyJ38iEnmpB6L2Z9", keyPairBiggerThan258bits.public.toStringShort()) } @Test fun `ECDSA R1 keyPair from entropy`() { - val keyPairPositive = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("10")) + val keyPairPositive = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("10")) assertEquals("DLHDcxuSt9J3cbjd2Dsx4rAgYYA7BAP7A8VLrFiq1tH9yy", keyPairPositive.public.toStringShort()) // The underlying implementation uses the hash of entropy if it is out of range 2 < entropy < N, where N the order of the group. - val keyPairNegative = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("-10")) + val keyPairNegative = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("-10")) assertEquals("DLBASmjiMZuu1g3EtdHJxfSueXE8PRoUWbkdU61Qcnpamt", keyPairNegative.public.toStringShort()) - val keyPairZero = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("0")) + val keyPairZero = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("0")) assertEquals("DLH2FEHEnsT3MpCJt2gfyNjpqRqcBxeupK4YRPXvDsVEkb", keyPairZero.public.toStringShort()) // BigIntenger.Zero is out or range, so 1 and hash(1.toByteArray) would have the same impact. val zeroHashed = BigInteger(1, BigInteger("0").toByteArray().sha256().bytes) // Check oneHashed < N (order of the group), otherwise we would need an extra hash. - assertEquals(-1, zeroHashed.compareTo((Crypto.ECDSA_SECP256R1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) - val keyPairZeroHashed = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, zeroHashed) + assertEquals(-1, zeroHashed.compareTo((ECDSA_SECP256R1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) + val keyPairZeroHashed = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, zeroHashed) assertEquals("DLH2FEHEnsT3MpCJt2gfyNjpqRqcBxeupK4YRPXvDsVEkb", keyPairZeroHashed.public.toStringShort()) - val keyPairOne = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("1")) + val keyPairOne = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("1")) assertEquals("DLHrtKwjv6onq9HcrQDJPs8Cgtai5mZU5ZU6sb1ivJjx3z", keyPairOne.public.toStringShort()) // BigIntenger.ONE is out or range, so 1 and hash(1.toByteArray) would have the same impact. val oneHashed = BigInteger(1, BigInteger("1").toByteArray().sha256().bytes) // Check oneHashed < N (order of the group), otherwise we would need an extra hash. - assertEquals(-1, oneHashed.compareTo((Crypto.ECDSA_SECP256R1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) - val keyPairOneHashed = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, oneHashed) + assertEquals(-1, oneHashed.compareTo((ECDSA_SECP256R1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) + val keyPairOneHashed = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, oneHashed) assertEquals("DLHrtKwjv6onq9HcrQDJPs8Cgtai5mZU5ZU6sb1ivJjx3z", keyPairOneHashed.public.toStringShort()) // 2 is in the range. - val keyPairTwo = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("2")) + val keyPairTwo = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("2")) assertEquals("DLFoz6txJ3vHcKNSM1vFxHJUoEQ69PorBwW64dHsAnEoZB", keyPairTwo.public.toStringShort()) // Try big numbers that are out of range. - val keyPairBiggerThan256bits = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("2").pow(258).minus(BigInteger.TEN)) + val keyPairBiggerThan256bits = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("2").pow(258).minus(BigInteger.TEN)) assertEquals("DLBv6fZqaCTbE4L7sgjbt19biXHMgU9CzR5s8g8XBJjZ11", keyPairBiggerThan256bits.public.toStringShort()) - val keyPairBiggerThan256bitsV2 = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("2").pow(258).minus(BigInteger("50"))) + val keyPairBiggerThan256bitsV2 = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("2").pow(258).minus(BigInteger("50"))) assertEquals("DLANmjhGSVdLyghxcPHrn3KuGatscf6LtvqifUDxw7SGU8", keyPairBiggerThan256bitsV2.public.toStringShort()) - val keyPairBiggerThan512bits = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("2").pow(514).minus(BigInteger.TEN)) + val keyPairBiggerThan512bits = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("2").pow(514).minus(BigInteger.TEN)) assertEquals("DL9sKwMExBTD3MnJN6LWGqo496Erkebs9fxZtXLVJUBY9Z", keyPairBiggerThan512bits.public.toStringShort()) - val keyPairBiggerThan258bits = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256R1_SHA256, BigInteger("2").pow(259).plus(BigInteger.ONE)) + val keyPairBiggerThan258bits = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256R1_SHA256, BigInteger("2").pow(259).plus(BigInteger.ONE)) assertEquals("DLBwjWwPJSF9E7b1NWaSbEJ4oK8CF7RDGWd648TiBhZoL1", keyPairBiggerThan258bits.public.toStringShort()) } @Test fun `ECDSA K1 keyPair from entropy`() { - val keyPairPositive = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("10")) + val keyPairPositive = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("10")) assertEquals("DL6pYKUgH17az8MLdonvvUtUPN8TqwpCGcdgLr7vg3skCU", keyPairPositive.public.toStringShort()) // The underlying implementation uses the hash of entropy if it is out of range 2 <= entropy < N, where N the order of the group. - val keyPairNegative = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("-10")) + val keyPairNegative = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("-10")) assertEquals("DLnpXhxece69Nyqgm3pPt3yV7ESQYDJKoYxs1hKgfBAEu", keyPairNegative.public.toStringShort()) - val keyPairZero = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("0")) + val keyPairZero = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("0")) assertEquals("DLBC28e18T6KsYwjTFfUWJfhvHjvYVapyVf6antnqUkbgd", keyPairZero.public.toStringShort()) // BigIntenger.Zero is out or range, so 1 and hash(1.toByteArray) would have the same impact. val zeroHashed = BigInteger(1, BigInteger("0").toByteArray().sha256().bytes) // Check oneHashed < N (order of the group), otherwise we would need an extra hash. - assertEquals(-1, zeroHashed.compareTo((Crypto.ECDSA_SECP256K1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) - val keyPairZeroHashed = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, zeroHashed) + assertEquals(-1, zeroHashed.compareTo((ECDSA_SECP256K1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) + val keyPairZeroHashed = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, zeroHashed) assertEquals("DLBC28e18T6KsYwjTFfUWJfhvHjvYVapyVf6antnqUkbgd", keyPairZeroHashed.public.toStringShort()) - val keyPairOne = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("1")) + val keyPairOne = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("1")) assertEquals("DLBimRXdEQhJUTpL6f9ri9woNdsze6mwkRrhsML13Eh7ET", keyPairOne.public.toStringShort()) // BigIntenger.ONE is out or range, so 1 and hash(1.toByteArray) would have the same impact. val oneHashed = BigInteger(1, BigInteger("1").toByteArray().sha256().bytes) // Check oneHashed < N (order of the group), otherwise we would need an extra hash. - assertEquals(-1, oneHashed.compareTo((Crypto.ECDSA_SECP256K1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) - val keyPairOneHashed = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, oneHashed) + assertEquals(-1, oneHashed.compareTo((ECDSA_SECP256K1_SHA256.algSpec as ECNamedCurveParameterSpec).n)) + val keyPairOneHashed = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, oneHashed) assertEquals("DLBimRXdEQhJUTpL6f9ri9woNdsze6mwkRrhsML13Eh7ET", keyPairOneHashed.public.toStringShort()) // 2 is in the range. - val keyPairTwo = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("2")) + val keyPairTwo = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("2")) assertEquals("DLG32UWaevGw9YY7w1Rf9mmK88biavgpDnJA9bG4GapVPs", keyPairTwo.public.toStringShort()) // Try big numbers that are out of range. - val keyPairBiggerThan256bits = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("2").pow(258).minus(BigInteger.TEN)) + val keyPairBiggerThan256bits = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("2").pow(258).minus(BigInteger.TEN)) assertEquals("DLGHsdv2xeAuM7n3sBc6mFfiphXe6VSf3YxqvviKDU6Vbd", keyPairBiggerThan256bits.public.toStringShort()) - val keyPairBiggerThan256bitsV2 = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("2").pow(258).minus(BigInteger("50"))) + val keyPairBiggerThan256bitsV2 = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("2").pow(258).minus(BigInteger("50"))) assertEquals("DL9yJfiNGqteRrKPjGUkRQkeqzuQ4kwcYQWMCi5YKuUHrk", keyPairBiggerThan256bitsV2.public.toStringShort()) - val keyPairBiggerThan512bits = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("2").pow(514).minus(BigInteger.TEN)) + val keyPairBiggerThan512bits = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("2").pow(514).minus(BigInteger.TEN)) assertEquals("DL3Wr5EQGrMTaKBy5XMvG8rvSfKX1AYZLCRU8kixGbxt1E", keyPairBiggerThan512bits.public.toStringShort()) - val keyPairBiggerThan258bits = Crypto.deriveKeyPairFromEntropy(Crypto.ECDSA_SECP256K1_SHA256, BigInteger("2").pow(259).plus(BigInteger.ONE)) + val keyPairBiggerThan258bits = Crypto.deriveKeyPairFromEntropy(ECDSA_SECP256K1_SHA256, BigInteger("2").pow(259).plus(BigInteger.ONE)) assertEquals("DL7NbssqvuuJ4cqFkkaVYu9j1MsVswESGgCfbqBS9ULwuM", keyPairBiggerThan258bits.public.toStringShort()) } + + @Test + fun `Ensure deterministic signatures of EdDSA, SPHINCS-256 and RSA PKCS1`() { + listOf(EDDSA_ED25519_SHA512, SPHINCS256_SHA256, RSA_SHA256) + .forEach { testDeterministicSignatures(it) } + } + + private fun testDeterministicSignatures(signatureScheme: SignatureScheme) { + val privateKey = Crypto.generateKeyPair(signatureScheme).private + val signedData1stTime = Crypto.doSign(privateKey, testBytes) + val signedData2ndTime = Crypto.doSign(privateKey, testBytes) + assertEquals(OpaqueBytes(signedData1stTime), OpaqueBytes(signedData2ndTime)) + + // Try for the special case of signing a zero array. + val signedZeroArray1stTime = Crypto.doSign(privateKey, test100ZeroBytes) + val signedZeroArray2ndTime = Crypto.doSign(privateKey, test100ZeroBytes) + assertEquals(OpaqueBytes(signedZeroArray1stTime), OpaqueBytes(signedZeroArray2ndTime)) + + // Just in case, test that signatures of different messages are not the same. + assertNotEquals(OpaqueBytes(signedData1stTime), OpaqueBytes(signedZeroArray1stTime)) + } } From b8b2cc772d27ab78330557f4deb13be112a63642 Mon Sep 17 00:00:00 2001 From: Andrius Dagys Date: Wed, 10 Oct 2018 13:31:29 +0100 Subject: [PATCH 2/6] CORDA-535: Remove the old mechanism for loading custom notary service implementations. All notary service implementations are now assumed to be loaded from CorDapps. --- docs/source/tutorial-custom-notary.rst | 13 ++-- .../net/corda/node/internal/AbstractNode.kt | 75 +++++-------------- .../node/services/config/NodeConfiguration.kt | 5 +- .../net/corda/node/services/TimedFlowTests.kt | 1 - samples/notary-demo/build.gradle | 5 +- .../corda/notarydemo/MyCustomNotaryService.kt | 5 +- 6 files changed, 32 insertions(+), 72 deletions(-) diff --git a/docs/source/tutorial-custom-notary.rst b/docs/source/tutorial-custom-notary.rst index cd102e484f..cf7b78a0ff 100644 --- a/docs/source/tutorial-custom-notary.rst +++ b/docs/source/tutorial-custom-notary.rst @@ -4,13 +4,12 @@ Writing a custom notary service (experimental) ============================================== .. warning:: Customising a notary service is still an experimental feature and not recommended for most use-cases. The APIs - for writing a custom notary may change in the future. Additionally, customising Raft or BFT notaries is not yet - fully supported. If you want to write your own Raft notary you will have to implement a custom database connector - (or use a separate database for the notary), and use a custom configuration file. + for writing a custom notary may change in the future. -Similarly to writing an oracle service, the first step is to create a service class in your CorDapp and annotate it -with ``@CordaService``. The Corda node scans for any class with this annotation and initialises them. The custom notary -service class should provide a constructor with two parameters of types ``AppServiceHub`` and ``PublicKey``. +The first step is to create a service class in your CorDapp that extends the ``NotaryService`` abstract class. +This will ensure that it is recognised as a notary service. +The custom notary service class should provide a constructor with two parameters of types ``ServiceHubInternal`` and ``PublicKey``. +Note that ``ServiceHubInternal`` does not provide any API stability guarantees. .. literalinclude:: ../../samples/notary-demo/src/main/kotlin/net/corda/notarydemo/MyCustomNotaryService.kt :language: kotlin @@ -32,5 +31,5 @@ To enable the service, add the following to the node configuration: notary : { validating : true # Set to false if your service is non-validating - custom : true + className : "net.corda.notarydemo.MyCustomValidatingNotaryService" # The fully qualified name of your service class } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index bea2d435ae..141c92778b 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -42,9 +42,12 @@ import net.corda.node.services.ContractUpgradeHandler import net.corda.node.services.FinalityHandler import net.corda.node.services.NotaryChangeHandler import net.corda.node.services.api.* -import net.corda.node.services.config.* +import net.corda.node.services.config.NodeConfiguration +import net.corda.node.services.config.NotaryConfig +import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.rpc.NodeRpcOptions import net.corda.node.services.config.shell.toShellConfig +import net.corda.node.services.config.shouldInitCrashShell import net.corda.node.services.events.NodeSchedulerService import net.corda.node.services.events.ScheduledActivityObserver import net.corda.node.services.identity.PersistentIdentityService @@ -59,7 +62,8 @@ import net.corda.node.services.network.PersistentNetworkMapCache import net.corda.node.services.persistence.* import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.statemachine.* -import net.corda.node.services.transactions.* +import net.corda.node.services.transactions.InMemoryTransactionVerifierService +import net.corda.node.services.transactions.SimpleNotaryService import net.corda.node.services.upgrade.ContractUpgradeServiceImpl import net.corda.node.services.vault.NodeVaultService import net.corda.node.utilities.* @@ -518,9 +522,9 @@ abstract class AbstractNode(val configuration: NodeConfiguration, private fun installCordaServices(myNotaryIdentity: PartyAndCertificate?) { val loadedServices = cordappLoader.cordapps.flatMap { it.services } - filterServicesToInstall(loadedServices).forEach { + loadedServices.forEach { try { - installCordaService(flowStarter, it, myNotaryIdentity) + installCordaService(flowStarter, it) } catch (e: NoSuchMethodException) { log.error("${it.name}, as a Corda service, must have a constructor with a single parameter of type " + ServiceHub::class.java.name) @@ -532,24 +536,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } } - private fun filterServicesToInstall(loadedServices: List>): List> { - val customNotaryServiceList = loadedServices.filter { isNotaryService(it) } - if (customNotaryServiceList.isNotEmpty()) { - if (configuration.notary?.custom == true) { - require(customNotaryServiceList.size == 1) { - "Attempting to install more than one notary service: ${customNotaryServiceList.joinToString()}" - } - } else return loadedServices - customNotaryServiceList - } - return loadedServices - } - - /** - * If the [serviceClass] is a notary service, it will only be enabled if the "custom" flag is set in - * the notary configuration. - */ - private fun isNotaryService(serviceClass: Class<*>) = NotaryService::class.java.isAssignableFrom(serviceClass) - /** * This customizes the ServiceHub for each CordaService that is initiating flows. */ @@ -590,53 +576,30 @@ abstract class AbstractNode(val configuration: NodeConfiguration, override fun hashCode() = Objects.hash(serviceHub, flowStarter, serviceInstance) } - private fun installCordaService(flowStarter: FlowStarter, serviceClass: Class, myNotaryIdentity: PartyAndCertificate?) { + private fun installCordaService(flowStarter: FlowStarter, serviceClass: Class) { serviceClass.requireAnnotation() val service = try { - if (isNotaryService(serviceClass)) { - myNotaryIdentity ?: throw IllegalStateException("Trying to install a notary service but no notary identity specified") - try { - val constructor = serviceClass.getDeclaredConstructor(ServiceHubInternal::class.java, PublicKey::class.java).apply { isAccessible = true } - constructor.newInstance(services, myNotaryIdentity.owningKey ) - } catch (ex: NoSuchMethodException) { - val constructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java, PublicKey::class.java).apply { isAccessible = true } - val serviceContext = AppServiceHubImpl(services, flowStarter) - val service = constructor.newInstance(serviceContext, myNotaryIdentity.owningKey) - serviceContext.serviceInstance = service - service - } - } else { - try { - val serviceContext = AppServiceHubImpl(services, flowStarter) - val extendedServiceConstructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java).apply { isAccessible = true } - val service = extendedServiceConstructor.newInstance(serviceContext) - serviceContext.serviceInstance = service - service - } catch (ex: NoSuchMethodException) { - val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java).apply { isAccessible = true } - log.warn("${serviceClass.name} is using legacy CordaService constructor with ServiceHub parameter. " + - "Upgrade to an AppServiceHub parameter to enable updated API features.") - constructor.newInstance(services) - } - } + val serviceContext = AppServiceHubImpl(services, flowStarter) + val extendedServiceConstructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java).apply { isAccessible = true } + val service = extendedServiceConstructor.newInstance(serviceContext) + serviceContext.serviceInstance = service + service + } catch (ex: NoSuchMethodException) { + val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java).apply { isAccessible = true } + log.warn("${serviceClass.name} is using legacy CordaService constructor with ServiceHub parameter. " + + "Upgrade to an AppServiceHub parameter to enable updated API features.") + constructor.newInstance(services) } catch (e: InvocationTargetException) { throw ServiceInstantiationException(e.cause) } cordappServices.putInstance(serviceClass, service) - if (service is NotaryService) handleCustomNotaryService(service) service.tokenize() log.info("Installed ${serviceClass.name} Corda service") } - private fun handleCustomNotaryService(service: NotaryService) { - runOnStop += service::stop - installCoreFlow(NotaryFlow.Client::class, service::createServiceFlow) - service.start() - } - private fun registerCordappFlows() { cordappLoader.cordapps.flatMap { it.initiatedFlows } .forEach { diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index 28e73e4600..5da2e9724f 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -122,13 +122,12 @@ fun NodeConfiguration.shouldInitCrashShell() = shouldStartLocalShell() || should data class NotaryConfig(val validating: Boolean, val raft: RaftConfig? = null, val bftSMaRt: BFTSMaRtConfiguration? = null, - val custom: Boolean = false, val serviceLegalName: CordaX500Name? = null, val className: String = "net.corda.node.services.transactions.SimpleNotaryService" ) { init { - require(raft == null || bftSMaRt == null || !custom) { - "raft, bftSMaRt, and custom configs cannot be specified together" + require(raft == null || bftSMaRt == null) { + "raft and bftSMaRt configs cannot be specified together" } } diff --git a/node/src/test/kotlin/net/corda/node/services/TimedFlowTests.kt b/node/src/test/kotlin/net/corda/node/services/TimedFlowTests.kt index 71eef309ad..3dd2d51152 100644 --- a/node/src/test/kotlin/net/corda/node/services/TimedFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/TimedFlowTests.kt @@ -88,7 +88,6 @@ class TimedFlowTests { val networkParameters = NetworkParametersCopier(testNetworkParameters(listOf(NotaryInfo(notaryIdentity, true)))) val notaryConfig = mock { - whenever(it.custom).thenReturn(true) whenever(it.isClusterConfig).thenReturn(true) whenever(it.validating).thenReturn(true) whenever(it.className).thenReturn(TestNotaryService::class.java.name) diff --git a/samples/notary-demo/build.gradle b/samples/notary-demo/build.gradle index 7f52642b53..acf513708c 100644 --- a/samples/notary-demo/build.gradle +++ b/samples/notary-demo/build.gradle @@ -76,7 +76,10 @@ task deployNodesCustom(type: Cordform, dependsOn: 'jar') { address "localhost:10010" adminAddress "localhost:10110" } - notary = [validating: true, "custom": true] + notary = [ + validating: true, + className: "net.corda.notarydemo.MyCustomValidatingNotaryService" + ] } } diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/MyCustomNotaryService.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/MyCustomNotaryService.kt index ad684fc489..51bc918242 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/MyCustomNotaryService.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/MyCustomNotaryService.kt @@ -8,8 +8,6 @@ import net.corda.core.internal.ResolveTransactionsFlow import net.corda.core.internal.notary.NotaryInternalException import net.corda.core.internal.notary.NotaryServiceFlow import net.corda.core.internal.notary.TrustedAuthorityNotaryService -import net.corda.core.node.AppServiceHub -import net.corda.core.node.services.CordaService import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionWithSignatures import net.corda.core.transactions.WireTransaction @@ -19,13 +17,12 @@ import java.security.PublicKey import java.security.SignatureException /** - * A custom notary service should provide a constructor that accepts two parameters of types [AppServiceHub] and [PublicKey]. + * A custom notary service should provide a constructor that accepts two parameters of types [ServiceHubInternal] and [PublicKey]. * * Note that the support for custom notaries is still experimental – at present only a single-node notary service can be customised. * The notary-related APIs might change in the future. */ // START 1 -@CordaService class MyCustomValidatingNotaryService(override val services: ServiceHubInternal, override val notaryIdentityKey: PublicKey) : TrustedAuthorityNotaryService() { override val uniquenessProvider = PersistentUniquenessProvider(services.clock, services.database, services.cacheFactory) From 0e68f26c0f60207399d9978fe8eaef56bf451689 Mon Sep 17 00:00:00 2001 From: Viktor Kolomeyko Date: Wed, 10 Oct 2018 17:52:00 +0100 Subject: [PATCH 3/6] ENT-2569: Clean-up content of `registeredShutdowns`. (#4048) Please see comment for more info. --- .../net/corda/testing/node/internal/ShutdownManager.kt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt index f69c4e8de1..f7cd8f7823 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt @@ -48,7 +48,12 @@ class ShutdownManager(private val executorService: ExecutorService) { emptyList Unit>>() } else { isShutdown = true - registeredShutdowns + val result = ArrayList(registeredShutdowns) + // It is important to clear `registeredShutdowns` that has been actioned upon as more than 1 driver can be created per test. + // Given that `ShutdownManager` is reachable from `ApplicationShutdownHooks`, everything that was scheduled for shutdown + // during 1st driver launch will not be eligible for GC during second driver launch therefore retained in memory. + registeredShutdowns.clear() + result } } From 825c544cacb3f7e53646b2eb59be6efd4a436559 Mon Sep 17 00:00:00 2001 From: Chris Rankin Date: Thu, 11 Oct 2018 13:48:32 +0100 Subject: [PATCH 4/6] ENT-1906: Modify the DJVM to wrap Java primitive types. (#4035) * WIP - sandbox classloading * Fix handling of Appendable in the sandbox. * WIP - Load wrapped Java types into SandboxClassLoader. * Add explicit toDJVM() invocation after invoking Object.toString(). * Add simple ThreadLocal to the sandbox to complete AbstractStringBuilder. * Add support for Enum types inside the sandbox. * Simplify type conversions into and out of the sandbox. * Small refactors and comments to tidy up code. * Fix Enum support to include EnumSet and EnumMap. * Fix use of "$" in whitelist regexps. * Add extra methods (i.e. bridges) to stitched interfaces. * Rename ToDJVMStringWrapper to StringReturnTypeWrapper. * Support lambdas within the sandbox. * Fix mapping of java.lang.System into the sandbox. * Don't remap exception classes that we catch into sandbox classes. * Remove unnecessary "bootstrap" classes from the DJVM jar. * Ensure that Character.UnicodeScript is available inside the sandbox. * Tweak sandboxed implementations of System and Runtime. * Ensure that Character.UnicodeScript is loaded correctly as Enum type. * Disallow invoking methods of ClassLoader inside the sandbox. * Apply updates after review. * More review fixes. --- djvm/build.gradle | 14 + .../java/sandbox/java/lang/Appendable.java | 19 + .../main/java/sandbox/java/lang/Boolean.java | 100 ++++ .../src/main/java/sandbox/java/lang/Byte.java | 129 +++++ .../java/sandbox/java/lang/CharSequence.java | 21 + .../java/sandbox/java/lang/Character.java | 481 ++++++++++++++++++ .../java/sandbox/java/lang/Comparable.java | 8 + .../main/java/sandbox/java/lang/Double.java | 163 ++++++ .../src/main/java/sandbox/java/lang/Enum.java | 27 + .../main/java/sandbox/java/lang/Float.java | 163 ++++++ .../main/java/sandbox/java/lang/Integer.java | 241 +++++++++ .../main/java/sandbox/java/lang/Iterable.java | 15 + .../src/main/java/sandbox/java/lang/Long.java | 239 +++++++++ .../main/java/sandbox/java/lang/Number.java | 21 + .../main/java/sandbox/java/lang/Object.java | 71 +++ .../main/java/sandbox/java/lang/Runtime.java | 27 + .../main/java/sandbox/java/lang/Short.java | 128 +++++ .../main/java/sandbox/java/lang/String.java | 398 +++++++++++++++ .../java/sandbox/java/lang/StringBuffer.java | 20 + .../java/sandbox/java/lang/StringBuilder.java | 20 + .../main/java/sandbox/java/lang/System.java | 28 + .../java/sandbox/java/lang/ThreadLocal.java | 59 +++ .../sandbox/java/nio/charset/Charset.java | 18 + .../java/sandbox/java/util/Comparator.java | 9 + .../java/sandbox/java/util/LinkedHashMap.java | 13 + .../main/java/sandbox/java/util/Locale.java | 9 + djvm/src/main/java/sandbox/java/util/Map.java | 7 + .../sandbox/java/util/function/Function.java | 10 + .../sandbox/java/util/function/Supplier.java | 10 + .../java/sandbox/sun/misc/JavaLangAccess.java | 10 + .../java/sandbox/sun/misc/SharedSecrets.java | 20 + .../djvm/analysis/AnalysisConfiguration.kt | 127 ++++- .../djvm/analysis/ClassAndMemberVisitor.kt | 23 +- .../net/corda/djvm/analysis/ClassResolver.kt | 11 +- .../net/corda/djvm/analysis/Whitelist.kt | 39 +- .../net/corda/djvm/code/ClassMutator.kt | 50 +- .../net/corda/djvm/code/EmitterModule.kt | 94 +++- .../main/kotlin/net/corda/djvm/code/Types.kt | 4 +- .../code/instructions/ConstantInstruction.kt | 6 + .../djvm/code/instructions/MethodEntry.kt | 9 + .../corda/djvm/execution/SandboxExecutor.kt | 32 +- .../net/corda/djvm/rewiring/ClassRewriter.kt | 54 +- .../corda/djvm/rewiring/SandboxClassLoader.kt | 115 +++-- .../djvm/rewiring/SandboxClassRemapper.kt | 52 ++ .../corda/djvm/rewiring/SandboxClassWriter.kt | 6 +- .../corda/djvm/rewiring/SandboxRemapper.kt | 32 +- .../rules/implementation/ArgumentUnwrapper.kt | 35 ++ .../DisallowNonDeterministicMethods.kt | 14 +- .../rules/implementation/ReturnTypeWrapper.kt | 27 + .../implementation/RewriteClassMethods.kt | 56 ++ .../implementation/StaticConstantRemover.kt | 30 ++ .../implementation/StringConstantWrapper.kt | 22 + .../implementation/StubOutNativeMethods.kt | 2 +- .../StubOutReflectionMethods.kt | 2 +- .../net/corda/djvm/source/ClassSource.kt | 3 + .../net/corda/djvm/utilities/Discovery.kt | 2 +- .../corda/djvm/validation/RuleValidator.kt | 1 + djvm/src/main/kotlin/sandbox/Task.kt | 24 + .../src/main/kotlin/sandbox/java/lang/DJVM.kt | 158 ++++++ .../main/kotlin/sandbox/java/lang/Object.kt | 19 - .../main/kotlin/sandbox/java/lang/System.kt | 99 ---- .../kotlin/foo/bar/sandbox/KotlinClass.kt | 11 +- .../test/kotlin/net/corda/djvm/DJVMTest.kt | 126 +++++ .../test/kotlin/net/corda/djvm/TestBase.kt | 33 +- .../test/kotlin/net/corda/djvm/Utilities.kt | 22 + .../corda/djvm/analysis/ClassResolverTest.kt | 10 +- .../net/corda/djvm/analysis/WhitelistTest.kt | 4 +- .../assertions/AssertiveClassWithByteCode.kt | 5 + .../net/corda/djvm/costing/RuntimeCostTest.kt | 9 +- .../corda/djvm/execution/SandboxEnumTest.kt | 86 ++++ .../djvm/execution/SandboxExecutorTest.kt | 320 ++++++++++-- .../corda/djvm/rewiring/ClassRewriterTest.kt | 48 +- .../djvm/source/SourceClassLoaderTest.kt | 2 +- 73 files changed, 3986 insertions(+), 336 deletions(-) create mode 100644 djvm/src/main/java/sandbox/java/lang/Appendable.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Boolean.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Byte.java create mode 100644 djvm/src/main/java/sandbox/java/lang/CharSequence.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Character.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Comparable.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Double.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Enum.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Float.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Integer.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Iterable.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Long.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Number.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Object.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Runtime.java create mode 100644 djvm/src/main/java/sandbox/java/lang/Short.java create mode 100644 djvm/src/main/java/sandbox/java/lang/String.java create mode 100644 djvm/src/main/java/sandbox/java/lang/StringBuffer.java create mode 100644 djvm/src/main/java/sandbox/java/lang/StringBuilder.java create mode 100644 djvm/src/main/java/sandbox/java/lang/System.java create mode 100644 djvm/src/main/java/sandbox/java/lang/ThreadLocal.java create mode 100644 djvm/src/main/java/sandbox/java/nio/charset/Charset.java create mode 100644 djvm/src/main/java/sandbox/java/util/Comparator.java create mode 100644 djvm/src/main/java/sandbox/java/util/LinkedHashMap.java create mode 100644 djvm/src/main/java/sandbox/java/util/Locale.java create mode 100644 djvm/src/main/java/sandbox/java/util/Map.java create mode 100644 djvm/src/main/java/sandbox/java/util/function/Function.java create mode 100644 djvm/src/main/java/sandbox/java/util/function/Supplier.java create mode 100644 djvm/src/main/java/sandbox/sun/misc/JavaLangAccess.java create mode 100644 djvm/src/main/java/sandbox/sun/misc/SharedSecrets.java create mode 100644 djvm/src/main/kotlin/net/corda/djvm/code/instructions/ConstantInstruction.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/code/instructions/MethodEntry.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassRemapper.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ArgumentUnwrapper.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ReturnTypeWrapper.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/rules/implementation/RewriteClassMethods.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StaticConstantRemover.kt create mode 100644 djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StringConstantWrapper.kt create mode 100644 djvm/src/main/kotlin/sandbox/Task.kt create mode 100644 djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt delete mode 100644 djvm/src/main/kotlin/sandbox/java/lang/Object.kt delete mode 100644 djvm/src/main/kotlin/sandbox/java/lang/System.kt create mode 100644 djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt create mode 100644 djvm/src/test/kotlin/net/corda/djvm/Utilities.kt create mode 100644 djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt diff --git a/djvm/build.gradle b/djvm/build.gradle index db88e8c4c5..eb41df17cc 100644 --- a/djvm/build.gradle +++ b/djvm/build.gradle @@ -52,6 +52,20 @@ shadowJar { baseName 'corda-djvm' classifier '' relocate 'org.objectweb.asm', 'djvm.org.objectweb.asm' + + // These particular classes are only needed to "bootstrap" + // the compilation of the other sandbox classes. At runtime, + // we will generate better versions from deterministic-rt.jar. + exclude 'sandbox/java/lang/Appendable.class' + exclude 'sandbox/java/lang/CharSequence.class' + exclude 'sandbox/java/lang/Character\$*.class' + exclude 'sandbox/java/lang/Comparable.class' + exclude 'sandbox/java/lang/Enum.class' + exclude 'sandbox/java/lang/Iterable.class' + exclude 'sandbox/java/lang/StringBuffer.class' + exclude 'sandbox/java/lang/StringBuilder.class' + exclude 'sandbox/java/nio/**' + exclude 'sandbox/java/util/**' } assemble.dependsOn shadowJar diff --git a/djvm/src/main/java/sandbox/java/lang/Appendable.java b/djvm/src/main/java/sandbox/java/lang/Appendable.java new file mode 100644 index 0000000000..168607c511 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Appendable.java @@ -0,0 +1,19 @@ +package sandbox.java.lang; + +import java.io.IOException; + +/** + * This is a dummy class that implements just enough of [java.lang.Appendable] + * to keep [sandbox.java.lang.StringBuilder], [sandbox.java.lang.StringBuffer] + * and [sandbox.java.lang.String] honest. + * Note that it does not extend [java.lang.Appendable]. + */ +public interface Appendable { + + Appendable append(CharSequence csq, int start, int end) throws IOException; + + Appendable append(CharSequence csq) throws IOException; + + Appendable append(char c) throws IOException; + +} diff --git a/djvm/src/main/java/sandbox/java/lang/Boolean.java b/djvm/src/main/java/sandbox/java/lang/Boolean.java new file mode 100644 index 0000000000..6d347fdd3e --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Boolean.java @@ -0,0 +1,100 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +import java.io.Serializable; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Boolean extends Object implements Comparable, Serializable { + + public static final Boolean TRUE = new Boolean(true); + public static final Boolean FALSE = new Boolean(false); + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Boolean.TYPE; + + private final boolean value; + + public Boolean(boolean value) { + this.value = value; + } + + public Boolean(String s) { + this(parseBoolean(s)); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Boolean) && ((Boolean) other).value == value; + } + + @Override + public int hashCode() { + return hashCode(value); + } + + public static int hashCode(boolean value) { + return java.lang.Boolean.hashCode(value); + } + + public boolean booleanValue() { + return value; + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Boolean.toString(value); + } + + @Override + @NotNull + public String toDJVMString() { + return toString(value); + } + + public static String toString(boolean b) { + return String.valueOf(b); + } + + @Override + @NotNull + java.lang.Boolean fromDJVM() { + return value; + } + + @Override + public int compareTo(@NotNull Boolean other) { + return compare(value, other.value); + } + + public static int compare(boolean x, boolean y) { + return java.lang.Boolean.compare(x, y); + } + + public static boolean parseBoolean(String s) { + return java.lang.Boolean.parseBoolean(String.fromDJVM(s)); + } + + public static Boolean valueOf(boolean b) { + return b ? TRUE : FALSE; + } + + public static Boolean valueOf(String s) { + return valueOf(parseBoolean(s)); + } + + public static boolean logicalAnd(boolean a, boolean b) { + return java.lang.Boolean.logicalAnd(a, b); + } + + public static boolean logicalOr(boolean a, boolean b) { + return java.lang.Boolean.logicalOr(a, b); + } + + public static boolean logicalXor(boolean a, boolean b) { + return java.lang.Boolean.logicalXor(a, b); + } + + public static Boolean toDJVM(java.lang.Boolean b) { return (b == null) ? null : new Boolean(b); } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Byte.java b/djvm/src/main/java/sandbox/java/lang/Byte.java new file mode 100644 index 0000000000..95b329f25d --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Byte.java @@ -0,0 +1,129 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Byte extends Number implements Comparable { + public static final byte MIN_VALUE = java.lang.Byte.MIN_VALUE; + public static final byte MAX_VALUE = java.lang.Byte.MAX_VALUE; + public static final int BYTES = java.lang.Byte.BYTES; + public static final int SIZE = java.lang.Byte.SIZE; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Byte.TYPE; + + private final byte value; + + public Byte(byte value) { + this.value = value; + } + + public Byte(String s) throws NumberFormatException { + this.value = parseByte(s); + } + + @Override + public byte byteValue() { + return value; + } + + @Override + public short shortValue() { + return (short) value; + } + + @Override + public int intValue() { + return (int) value; + } + + @Override + public long longValue() { + return (long) value; + } + + @Override + public float floatValue() { + return (float) value; + } + + @Override + public double doubleValue() { + return (double) value; + } + + @Override + public int hashCode() { + return hashCode(value); + } + + public static int hashCode(byte b) { + return java.lang.Byte.hashCode(b); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Byte) && ((Byte) other).value == value; + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Byte.toString(value); + } + + @Override + @NotNull + java.lang.Byte fromDJVM() { + return value; + } + + @Override + public int compareTo(@NotNull Byte other) { + return compare(this.value, other.value); + } + + public static int compare(byte x, byte y) { + return java.lang.Byte.compare(x, y); + } + + public static String toString(byte b) { + return Integer.toString(b); + } + + public static Byte valueOf(byte b) { + return new Byte(b); + } + + public static byte parseByte(String s, int radix) throws NumberFormatException { + return java.lang.Byte.parseByte(String.fromDJVM(s), radix); + } + + public static byte parseByte(String s) throws NumberFormatException { + return java.lang.Byte.parseByte(String.fromDJVM(s)); + } + + public static Byte valueOf(String s, int radix) throws NumberFormatException { + return toDJVM(java.lang.Byte.valueOf(String.fromDJVM(s), radix)); + } + + public static Byte valueOf(String s) throws NumberFormatException { + return toDJVM(java.lang.Byte.valueOf(String.fromDJVM(s))); + } + + public static Byte decode(String s) throws NumberFormatException { + return toDJVM(java.lang.Byte.decode(String.fromDJVM(s))); + } + + public static int toUnsignedInt(byte b) { + return java.lang.Byte.toUnsignedInt(b); + } + + public static long toUnsignedLong(byte b) { + return java.lang.Byte.toUnsignedLong(b); + } + + public static Byte toDJVM(java.lang.Byte b) { + return (b == null) ? null : valueOf(b); + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/CharSequence.java b/djvm/src/main/java/sandbox/java/lang/CharSequence.java new file mode 100644 index 0000000000..1847103093 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/CharSequence.java @@ -0,0 +1,21 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +/** + * This is a dummy class that implements just enough of [java.lang.CharSequence] + * to allow us to compile [sandbox.java.lang.String]. + */ +public interface CharSequence extends java.lang.CharSequence { + + @Override + CharSequence subSequence(int start, int end); + + @NotNull + String toDJVMString(); + + @Override + @NotNull + java.lang.String toString(); + +} diff --git a/djvm/src/main/java/sandbox/java/lang/Character.java b/djvm/src/main/java/sandbox/java/lang/Character.java new file mode 100644 index 0000000000..2db6054272 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Character.java @@ -0,0 +1,481 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +import java.io.Serializable; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Character extends Object implements Comparable, Serializable { + public static final int MIN_RADIX = java.lang.Character.MIN_RADIX; + public static final int MAX_RADIX = java.lang.Character.MAX_RADIX; + public static final char MIN_VALUE = java.lang.Character.MIN_VALUE; + public static final char MAX_VALUE = java.lang.Character.MAX_VALUE; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Character.TYPE; + + public static final byte UNASSIGNED = java.lang.Character.UNASSIGNED; + public static final byte UPPERCASE_LETTER = java.lang.Character.UPPERCASE_LETTER; + public static final byte LOWERCASE_LETTER = java.lang.Character.LOWERCASE_LETTER; + public static final byte TITLECASE_LETTER = java.lang.Character.TITLECASE_LETTER; + public static final byte MODIFIER_LETTER = java.lang.Character.MODIFIER_LETTER; + public static final byte OTHER_LETTER = java.lang.Character.OTHER_LETTER; + public static final byte NON_SPACING_MARK = java.lang.Character.NON_SPACING_MARK; + public static final byte ENCLOSING_MARK = java.lang.Character.ENCLOSING_MARK; + public static final byte COMBINING_SPACING_MARK = java.lang.Character.COMBINING_SPACING_MARK; + public static final byte DECIMAL_DIGIT_NUMBER = java.lang.Character.DECIMAL_DIGIT_NUMBER; + public static final byte LETTER_NUMBER = java.lang.Character.LETTER_NUMBER; + public static final byte OTHER_NUMBER = java.lang.Character.OTHER_NUMBER; + public static final byte SPACE_SEPARATOR = java.lang.Character.SPACE_SEPARATOR; + public static final byte LINE_SEPARATOR = java.lang.Character.LINE_SEPARATOR; + public static final byte PARAGRAPH_SEPARATOR = java.lang.Character.PARAGRAPH_SEPARATOR; + public static final byte CONTROL = java.lang.Character.CONTROL; + public static final byte FORMAT = java.lang.Character.FORMAT; + public static final byte PRIVATE_USE = java.lang.Character.PRIVATE_USE; + public static final byte SURROGATE = java.lang.Character.SURROGATE; + public static final byte DASH_PUNCTUATION = java.lang.Character.DASH_PUNCTUATION; + public static final byte START_PUNCTUATION = java.lang.Character.START_PUNCTUATION; + public static final byte END_PUNCTUATION = java.lang.Character.END_PUNCTUATION; + public static final byte CONNECTOR_PUNCTUATION = java.lang.Character.CONNECTOR_PUNCTUATION; + public static final byte OTHER_PUNCTUATION = java.lang.Character.OTHER_PUNCTUATION; + public static final byte MATH_SYMBOL = java.lang.Character.MATH_SYMBOL; + public static final byte CURRENCY_SYMBOL = java.lang.Character.CURRENCY_SYMBOL; + public static final byte MODIFIER_SYMBOL = java.lang.Character.MODIFIER_SYMBOL; + public static final byte OTHER_SYMBOL = java.lang.Character.OTHER_SYMBOL; + public static final byte INITIAL_QUOTE_PUNCTUATION = java.lang.Character.INITIAL_QUOTE_PUNCTUATION; + public static final byte FINAL_QUOTE_PUNCTUATION = java.lang.Character.FINAL_QUOTE_PUNCTUATION; + public static final byte DIRECTIONALITY_UNDEFINED = java.lang.Character.DIRECTIONALITY_UNDEFINED; + public static final byte DIRECTIONALITY_LEFT_TO_RIGHT = java.lang.Character.DIRECTIONALITY_LEFT_TO_RIGHT; + public static final byte DIRECTIONALITY_RIGHT_TO_LEFT = java.lang.Character.DIRECTIONALITY_RIGHT_TO_LEFT; + public static final byte DIRECTIONALITY_RIGHT_TO_LEFT_ARABIC = java.lang.Character.DIRECTIONALITY_RIGHT_TO_LEFT_ARABIC; + public static final byte DIRECTIONALITY_EUROPEAN_NUMBER = java.lang.Character.DIRECTIONALITY_EUROPEAN_NUMBER; + public static final byte DIRECTIONALITY_EUROPEAN_NUMBER_SEPARATOR = java.lang.Character.DIRECTIONALITY_EUROPEAN_NUMBER_SEPARATOR; + public static final byte DIRECTIONALITY_EUROPEAN_NUMBER_TERMINATOR = java.lang.Character.DIRECTIONALITY_EUROPEAN_NUMBER_TERMINATOR; + public static final byte DIRECTIONALITY_ARABIC_NUMBER = java.lang.Character.DIRECTIONALITY_ARABIC_NUMBER; + public static final byte DIRECTIONALITY_COMMON_NUMBER_SEPARATOR = java.lang.Character.DIRECTIONALITY_COMMON_NUMBER_SEPARATOR; + public static final byte DIRECTIONALITY_NONSPACING_MARK = java.lang.Character.DIRECTIONALITY_NONSPACING_MARK; + public static final byte DIRECTIONALITY_BOUNDARY_NEUTRAL = java.lang.Character.DIRECTIONALITY_BOUNDARY_NEUTRAL; + public static final byte DIRECTIONALITY_PARAGRAPH_SEPARATOR = java.lang.Character.DIRECTIONALITY_PARAGRAPH_SEPARATOR; + public static final byte DIRECTIONALITY_SEGMENT_SEPARATOR = java.lang.Character.DIRECTIONALITY_SEGMENT_SEPARATOR; + public static final byte DIRECTIONALITY_WHITESPACE = java.lang.Character.DIRECTIONALITY_WHITESPACE; + public static final byte DIRECTIONALITY_OTHER_NEUTRALS = java.lang.Character.DIRECTIONALITY_OTHER_NEUTRALS; + public static final byte DIRECTIONALITY_LEFT_TO_RIGHT_EMBEDDING = java.lang.Character.DIRECTIONALITY_LEFT_TO_RIGHT_EMBEDDING; + public static final byte DIRECTIONALITY_LEFT_TO_RIGHT_OVERRIDE = java.lang.Character.DIRECTIONALITY_LEFT_TO_RIGHT_OVERRIDE; + public static final byte DIRECTIONALITY_RIGHT_TO_LEFT_EMBEDDING = java.lang.Character.DIRECTIONALITY_RIGHT_TO_LEFT_EMBEDDING; + public static final byte DIRECTIONALITY_RIGHT_TO_LEFT_OVERRIDE = java.lang.Character.DIRECTIONALITY_RIGHT_TO_LEFT_OVERRIDE; + public static final byte DIRECTIONALITY_POP_DIRECTIONAL_FORMAT = java.lang.Character.DIRECTIONALITY_POP_DIRECTIONAL_FORMAT; + public static final char MIN_HIGH_SURROGATE = java.lang.Character.MIN_HIGH_SURROGATE; + public static final char MAX_HIGH_SURROGATE = java.lang.Character.MAX_HIGH_SURROGATE; + public static final char MIN_LOW_SURROGATE = java.lang.Character.MIN_LOW_SURROGATE; + public static final char MAX_LOW_SURROGATE = java.lang.Character.MAX_LOW_SURROGATE; + public static final char MIN_SURROGATE = java.lang.Character.MIN_SURROGATE; + public static final char MAX_SURROGATE = java.lang.Character.MAX_SURROGATE; + public static final int MIN_SUPPLEMENTARY_CODE_POINT = java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT; + public static final int MIN_CODE_POINT = java.lang.Character.MIN_CODE_POINT; + public static final int MAX_CODE_POINT = java.lang.Character.MAX_CODE_POINT; + public static final int BYTES = java.lang.Character.BYTES; + public static final int SIZE = java.lang.Character.SIZE; + + private final char value; + + public Character(char c) { + this.value = c; + } + + public char charValue() { + return this.value; + } + + @Override + public int hashCode() { + return hashCode(this.value); + } + + public static int hashCode(char value) { + return java.lang.Character.hashCode(value); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Character) && ((Character) other).value == value; + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Character.toString(value); + } + + @Override + @NotNull + public String toDJVMString() { + return toString(value); + } + + @Override + @NotNull + java.lang.Character fromDJVM() { + return value; + } + + @Override + public int compareTo(@NotNull Character var1) { + return compare(this.value, var1.value); + } + + public static int compare(char x, char y) { + return java.lang.Character.compare(x, y); + } + + public static String toString(char c) { + return String.toDJVM(java.lang.Character.toString(c)); + } + + public static Character valueOf(char c) { + return (c <= 127) ? Cache.cache[(int)c] : new Character(c); + } + + public static boolean isValidCodePoint(int codePoint) { + return java.lang.Character.isValidCodePoint(codePoint); + } + + public static boolean isBmpCodePoint(int codePoint) { + return java.lang.Character.isBmpCodePoint(codePoint); + } + + public static boolean isSupplementaryCodePoint(int codePoint) { + return java.lang.Character.isSupplementaryCodePoint(codePoint); + } + + public static boolean isHighSurrogate(char ch) { + return java.lang.Character.isHighSurrogate(ch); + } + + public static boolean isLowSurrogate(char ch) { + return java.lang.Character.isLowSurrogate(ch); + } + + public static boolean isSurrogate(char ch) { + return java.lang.Character.isSurrogate(ch); + } + + public static boolean isSurrogatePair(char high, char low) { + return java.lang.Character.isSurrogatePair(high, low); + } + + public static int charCount(int codePoint) { + return java.lang.Character.charCount(codePoint); + } + + public static int toCodePoint(char high, char low) { + return java.lang.Character.toCodePoint(high, low); + } + + public static int codePointAt(CharSequence seq, int index) { + return java.lang.Character.codePointAt(seq, index); + } + + public static int codePointAt(char[] a, int index) { + return java.lang.Character.codePointAt(a, index); + } + + public static int codePointAt(char[] a, int index, int limit) { + return java.lang.Character.codePointAt(a, index, limit); + } + + public static int codePointBefore(CharSequence seq, int index) { + return java.lang.Character.codePointBefore(seq, index); + } + + public static int codePointBefore(char[] a, int index) { + return java.lang.Character.codePointBefore(a, index); + } + + public static int codePointBefore(char[] a, int index, int limit) { + return java.lang.Character.codePointBefore(a, index, limit); + } + + public static char highSurrogate(int codePoint) { + return java.lang.Character.highSurrogate(codePoint); + } + + public static char lowSurrogate(int codePoint) { + return java.lang.Character.lowSurrogate(codePoint); + } + + public static int toChars(int codePoint, char[] dst, int dstIndex) { + return java.lang.Character.toChars(codePoint, dst, dstIndex); + } + + public static char[] toChars(int codePoint) { + return java.lang.Character.toChars(codePoint); + } + + public static int codePointCount(CharSequence seq, int beginIndex, int endIndex) { + return java.lang.Character.codePointCount(seq, beginIndex, endIndex); + } + + public static int codePointCount(char[] a, int offset, int count) { + return java.lang.Character.codePointCount(a, offset, count); + } + + public static int offsetByCodePoints(CharSequence seq, int index, int codePointOffset) { + return java.lang.Character.offsetByCodePoints(seq, index, codePointOffset); + } + + public static int offsetByCodePoints(char[] a, int start, int count, int index, int codePointOffset) { + return java.lang.Character.offsetByCodePoints(a, start, count, index, codePointOffset); + } + + public static boolean isLowerCase(char ch) { + return java.lang.Character.isLowerCase(ch); + } + + public static boolean isLowerCase(int codePoint) { + return java.lang.Character.isLowerCase(codePoint); + } + + public static boolean isUpperCase(char ch) { + return java.lang.Character.isUpperCase(ch); + } + + public static boolean isUpperCase(int codePoint) { + return java.lang.Character.isUpperCase(codePoint); + } + + public static boolean isTitleCase(char ch) { + return java.lang.Character.isTitleCase(ch); + } + + public static boolean isTitleCase(int codePoint) { + return java.lang.Character.isTitleCase(codePoint); + } + + public static boolean isDigit(char ch) { + return java.lang.Character.isDigit(ch); + } + + public static boolean isDigit(int codePoint) { + return java.lang.Character.isDigit(codePoint); + } + + public static boolean isDefined(char ch) { + return java.lang.Character.isDefined(ch); + } + + public static boolean isDefined(int codePoint) { + return java.lang.Character.isDefined(codePoint); + } + + public static boolean isLetter(char ch) { + return java.lang.Character.isLetter(ch); + } + + public static boolean isLetter(int codePoint) { + return java.lang.Character.isLetter(codePoint); + } + + public static boolean isLetterOrDigit(char ch) { + return java.lang.Character.isLetterOrDigit(ch); + } + + public static boolean isLetterOrDigit(int codePoint) { + return java.lang.Character.isLetterOrDigit(codePoint); + } + + @Deprecated + public static boolean isJavaLetter(char ch) { + return java.lang.Character.isJavaLetter(ch); + } + + @Deprecated + public static boolean isJavaLetterOrDigit(char ch) { + return java.lang.Character.isJavaLetterOrDigit(ch); + } + + public static boolean isAlphabetic(int codePoint) { + return java.lang.Character.isAlphabetic(codePoint); + } + + public static boolean isIdeographic(int codePoint) { + return java.lang.Character.isIdeographic(codePoint); + } + + public static boolean isJavaIdentifierStart(char ch) { + return java.lang.Character.isJavaIdentifierStart(ch); + } + + public static boolean isJavaIdentifierStart(int codePoint) { + return java.lang.Character.isJavaIdentifierStart(codePoint); + } + + public static boolean isJavaIdentifierPart(char ch) { + return java.lang.Character.isJavaIdentifierPart(ch); + } + + public static boolean isJavaIdentifierPart(int codePoint) { + return java.lang.Character.isJavaIdentifierPart(codePoint); + } + + public static boolean isUnicodeIdentifierStart(char ch) { + return java.lang.Character.isUnicodeIdentifierStart(ch); + } + + public static boolean isUnicodeIdentifierStart(int codePoint) { + return java.lang.Character.isUnicodeIdentifierStart(codePoint); + } + + public static boolean isUnicodeIdentifierPart(char ch) { + return java.lang.Character.isUnicodeIdentifierPart(ch); + } + + public static boolean isUnicodeIdentifierPart(int codePoint) { + return java.lang.Character.isUnicodeIdentifierPart(codePoint); + } + + public static boolean isIdentifierIgnorable(char ch) { + return java.lang.Character.isIdentifierIgnorable(ch); + } + + public static boolean isIdentifierIgnorable(int codePoint) { + return java.lang.Character.isIdentifierIgnorable(codePoint); + } + + public static char toLowerCase(char ch) { + return java.lang.Character.toLowerCase(ch); + } + + public static int toLowerCase(int codePoint) { + return java.lang.Character.toLowerCase(codePoint); + } + + public static char toUpperCase(char ch) { + return java.lang.Character.toUpperCase(ch); + } + + public static int toUpperCase(int codePoint) { + return java.lang.Character.toUpperCase(codePoint); + } + + public static char toTitleCase(char ch) { + return java.lang.Character.toTitleCase(ch); + } + + public static int toTitleCase(int codePoint) { + return java.lang.Character.toTitleCase(codePoint); + } + + public static int digit(char ch, int radix) { + return java.lang.Character.digit(ch, radix); + } + + public static int digit(int codePoint, int radix) { + return java.lang.Character.digit(codePoint, radix); + } + + public static int getNumericValue(char ch) { + return java.lang.Character.getNumericValue(ch); + } + + public static int getNumericValue(int codePoint) { + return java.lang.Character.getNumericValue(codePoint); + } + + @Deprecated + public static boolean isSpace(char ch) { + return java.lang.Character.isSpace(ch); + } + + public static boolean isSpaceChar(char ch) { + return java.lang.Character.isSpaceChar(ch); + } + + public static boolean isSpaceChar(int codePoint) { + return java.lang.Character.isSpaceChar(codePoint); + } + + public static boolean isWhitespace(char ch) { + return java.lang.Character.isWhitespace(ch); + } + + public static boolean isWhitespace(int codePoint) { + return java.lang.Character.isWhitespace(codePoint); + } + + public static boolean isISOControl(char ch) { + return java.lang.Character.isISOControl(ch); + } + + public static boolean isISOControl(int codePoint) { + return java.lang.Character.isISOControl(codePoint); + } + + public static int getType(char ch) { + return java.lang.Character.getType(ch); + } + + public static int getType(int codePoint) { + return java.lang.Character.getType(codePoint); + } + + public static char forDigit(int digit, int radix) { + return java.lang.Character.forDigit(digit, radix); + } + + public static byte getDirectionality(char ch) { + return java.lang.Character.getDirectionality(ch); + } + + public static byte getDirectionality(int codePoint) { + return java.lang.Character.getDirectionality(codePoint); + } + + public static boolean isMirrored(char ch) { + return java.lang.Character.isMirrored(ch); + } + + public static boolean isMirrored(int codePoint) { + return java.lang.Character.isMirrored(codePoint); + } + + public static String getName(int codePoint) { + return String.toDJVM(java.lang.Character.getName(codePoint)); + } + + public static Character toDJVM(java.lang.Character c) { + return (c == null) ? null : valueOf(c); + } + + // These three nested classes are placeholders to ensure that + // the Character class bytecode is generated correctly. The + // real classes will be loaded from the from the bootstrap jar + // and then mapped into the sandbox.* namespace. + public static final class UnicodeScript extends Enum { + private UnicodeScript(String name, int index) { + super(name, index); + } + + @Override + public int compareTo(@NotNull UnicodeScript other) { + throw new UnsupportedOperationException("Bootstrap implementation"); + } + } + public static final class UnicodeBlock extends Subset {} + public static class Subset extends Object {} + + /** + * Keep pre-allocated instances of the first 128 characters + * on the basis that these will be used most frequently. + */ + private static class Cache { + private static final Character[] cache = new Character[128]; + + static { + for (int c = 0; c < cache.length; ++c) { + cache[c] = new Character((char) c); + } + } + + private Cache() {} + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Comparable.java b/djvm/src/main/java/sandbox/java/lang/Comparable.java new file mode 100644 index 0000000000..686539c1b4 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Comparable.java @@ -0,0 +1,8 @@ +package sandbox.java.lang; + +/** + * This is a dummy class that implements just enough of [java.lang.Comparable] + * to allow us to compile [sandbox.java.lang.String]. + */ +public interface Comparable extends java.lang.Comparable { +} diff --git a/djvm/src/main/java/sandbox/java/lang/Double.java b/djvm/src/main/java/sandbox/java/lang/Double.java new file mode 100644 index 0000000000..d3488edde2 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Double.java @@ -0,0 +1,163 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Double extends Number implements Comparable { + public static final double POSITIVE_INFINITY = java.lang.Double.POSITIVE_INFINITY; + public static final double NEGATIVE_INFINITY = java.lang.Double.NEGATIVE_INFINITY; + public static final double NaN = java.lang.Double.NaN; + public static final double MAX_VALUE = java.lang.Double.MAX_VALUE; + public static final double MIN_NORMAL = java.lang.Double.MIN_NORMAL; + public static final double MIN_VALUE = java.lang.Double.MIN_VALUE; + public static final int MAX_EXPONENT = java.lang.Double.MAX_EXPONENT; + public static final int MIN_EXPONENT = java.lang.Double.MIN_EXPONENT; + public static final int BYTES = java.lang.Double.BYTES; + public static final int SIZE = java.lang.Double.SIZE; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Double.TYPE; + + private final double value; + + public Double(double value) { + this.value = value; + } + + public Double(String s) throws NumberFormatException { + this.value = parseDouble(s); + } + + @Override + public double doubleValue() { + return value; + } + + @Override + public float floatValue() { + return (float)value; + } + + @Override + public long longValue() { + return (long)value; + } + + @Override + public int intValue() { + return (int)value; + } + + @Override + public short shortValue() { + return (short)value; + } + + @Override + public byte byteValue() { + return (byte)value; + } + + public boolean isNaN() { + return java.lang.Double.isNaN(value); + } + + public boolean isInfinite() { + return isInfinite(this.value); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Double) && doubleToLongBits(((Double)other).value) == doubleToLongBits(value); + } + + @Override + public int hashCode() { + return hashCode(value); + } + + public static int hashCode(double d) { + return java.lang.Double.hashCode(d); + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Double.toString(value); + } + + @Override + @NotNull + java.lang.Double fromDJVM() { + return value; + } + + @Override + public int compareTo(@NotNull Double other) { + return compare(this.value, other.value); + } + + public static String toString(double d) { + return String.toDJVM(java.lang.Double.toString(d)); + } + + public static String toHexString(double d) { + return String.toDJVM(java.lang.Double.toHexString(d)); + } + + public static Double valueOf(String s) throws NumberFormatException { + return toDJVM(java.lang.Double.valueOf(String.fromDJVM(s))); + } + + public static Double valueOf(double d) { + return new Double(d); + } + + public static double parseDouble(String s) throws NumberFormatException { + return java.lang.Double.parseDouble(String.fromDJVM(s)); + } + + public static boolean isNaN(double d) { + return java.lang.Double.isNaN(d); + } + + public static boolean isInfinite(double d) { + return java.lang.Double.isInfinite(d); + } + + public static boolean isFinite(double d) { + return java.lang.Double.isFinite(d); + } + + public static long doubleToLongBits(double d) { + return java.lang.Double.doubleToLongBits(d); + } + + public static long doubleToRawLongBits(double d) { + return java.lang.Double.doubleToRawLongBits(d); + } + + public static double longBitsToDouble(long bits) { + return java.lang.Double.longBitsToDouble(bits); + } + + public static int compare(double d1, double d2) { + return java.lang.Double.compare(d1, d2); + } + + public static double sum(double a, double b) { + return java.lang.Double.sum(a, b); + } + + public static double max(double a, double b) { + return java.lang.Double.max(a, b); + } + + public static double min(double a, double b) { + return java.lang.Double.min(a, b); + } + + public static Double toDJVM(java.lang.Double d) { + return (d == null) ? null : valueOf(d); + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Enum.java b/djvm/src/main/java/sandbox/java/lang/Enum.java new file mode 100644 index 0000000000..ffcdd8c916 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Enum.java @@ -0,0 +1,27 @@ +package sandbox.java.lang; + +import java.io.Serializable; + +/** + * This is a dummy class. We will load the actual Enum class at run-time. + */ +@SuppressWarnings("unused") +public abstract class Enum> extends Object implements Comparable, Serializable { + + private final String name; + private final int ordinal; + + protected Enum(String name, int ordinal) { + this.name = name; + this.ordinal = ordinal; + } + + public String name() { + return name; + } + + public int ordinal() { + return ordinal; + } + +} diff --git a/djvm/src/main/java/sandbox/java/lang/Float.java b/djvm/src/main/java/sandbox/java/lang/Float.java new file mode 100644 index 0000000000..bebc75f916 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Float.java @@ -0,0 +1,163 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Float extends Number implements Comparable { + public static final float POSITIVE_INFINITY = java.lang.Float.POSITIVE_INFINITY; + public static final float NEGATIVE_INFINITY = java.lang.Float.NEGATIVE_INFINITY; + public static final float NaN = java.lang.Float.NaN; + public static final float MAX_VALUE = java.lang.Float.MAX_VALUE; + public static final float MIN_NORMAL = java.lang.Float.MIN_NORMAL; + public static final float MIN_VALUE = java.lang.Float.MIN_VALUE; + public static final int MAX_EXPONENT = java.lang.Float.MAX_EXPONENT; + public static final int MIN_EXPONENT = java.lang.Float.MIN_EXPONENT; + public static final int BYTES = java.lang.Float.BYTES; + public static final int SIZE = java.lang.Float.SIZE; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Float.TYPE; + + private final float value; + + public Float(float value) { + this.value = value; + } + + public Float(String s) throws NumberFormatException { + this.value = parseFloat(s); + } + + @Override + public int hashCode() { + return hashCode(value); + } + + public static int hashCode(float f) { + return java.lang.Float.hashCode(f); + } + + @Override + public boolean equals(java.lang.Object other) { + return other instanceof Float && floatToIntBits(((Float)other).value) == floatToIntBits(this.value); + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Float.toString(value); + } + + @Override + @NotNull + java.lang.Float fromDJVM() { + return value; + } + + @Override + public double doubleValue() { + return (double)value; + } + + @Override + public float floatValue() { + return value; + } + + @Override + public long longValue() { + return (long)value; + } + + @Override + public int intValue() { + return (int)value; + } + + @Override + public short shortValue() { + return (short)value; + } + + @Override + public byte byteValue() { + return (byte)value; + } + + @Override + public int compareTo(@NotNull Float other) { + return compare(this.value, other.value); + } + + public boolean isNaN() { + return isNaN(value); + } + + public boolean isInfinite() { + return isInfinite(value); + } + + public static String toString(float f) { + return String.valueOf(f); + } + + public static String toHexString(float f) { + return String.toDJVM(java.lang.Float.toHexString(f)); + } + + public static Float valueOf(String s) throws NumberFormatException { + return toDJVM(java.lang.Float.valueOf(String.fromDJVM(s))); + } + + public static Float valueOf(float f) { + return new Float(f); + } + + public static float parseFloat(String s) throws NumberFormatException { + return java.lang.Float.parseFloat(String.fromDJVM(s)); + } + + public static boolean isNaN(float f) { + return java.lang.Float.isNaN(f); + } + + public static boolean isInfinite(float f) { + return java.lang.Float.isInfinite(f); + } + + public static boolean isFinite(float f) { + return java.lang.Float.isFinite(f); + } + + public static int floatToIntBits(float f) { + return java.lang.Float.floatToIntBits(f); + } + + public static int floatToRawIntBits(float f) { + return java.lang.Float.floatToIntBits(f); + } + + public static float intBitsToFloat(int bits) { + return java.lang.Float.intBitsToFloat(bits); + } + + public static int compare(float f1, float f2) { + return java.lang.Float.compare(f1, f2); + } + + public static float sum(float a, float b) { + return java.lang.Float.sum(a, b); + } + + public static float max(float a, float b) { + return java.lang.Float.max(a, b); + } + + public static float min(float a, float b) { + return java.lang.Float.min(a, b); + } + + public static Float toDJVM(java.lang.Float f) { + return (f == null) ? null : valueOf(f); + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Integer.java b/djvm/src/main/java/sandbox/java/lang/Integer.java new file mode 100644 index 0000000000..ae05ea0f91 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Integer.java @@ -0,0 +1,241 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Integer extends Number implements Comparable { + + public static final int MIN_VALUE = java.lang.Integer.MIN_VALUE; + public static final int MAX_VALUE = java.lang.Integer.MAX_VALUE; + public static final int BYTES = java.lang.Integer.BYTES; + public static final int SIZE = java.lang.Integer.SIZE; + + static final int[] SIZE_TABLE = new int[] { 9, 99, 999, 9999, 99999, 999999, 9999999, 99999999, 999999999, MAX_VALUE }; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Integer.TYPE; + + private final int value; + + public Integer(int value) { + this.value = value; + } + + public Integer(String s) throws NumberFormatException { + this.value = parseInt(s, 10); + } + + @Override + public int hashCode() { + return Integer.hashCode(value); + } + + public static int hashCode(int i) { + return java.lang.Integer.hashCode(i); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Integer) && (value == ((Integer) other).value); + } + + @Override + public int intValue() { + return value; + } + + @Override + public long longValue() { + return value; + } + + @Override + public short shortValue() { + return (short) value; + } + + @Override + public byte byteValue() { + return (byte) value; + } + + @Override + public float floatValue() { + return (float) value; + } + + @Override + public double doubleValue() { + return (double) value; + } + + @Override + public int compareTo(@NotNull Integer other) { + return compare(this.value, other.value); + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Integer.toString(value); + } + + @Override + @NotNull + java.lang.Integer fromDJVM() { + return value; + } + + public static String toString(int i, int radix) { + return String.toDJVM(java.lang.Integer.toString(i, radix)); + } + + public static String toUnsignedString(int i, int radix) { + return String.toDJVM(java.lang.Integer.toUnsignedString(i, radix)); + } + + public static String toHexString(int i) { + return String.toDJVM(java.lang.Integer.toHexString(i)); + } + + public static String toOctalString(int i) { + return String.toDJVM(java.lang.Integer.toOctalString(i)); + } + + public static String toBinaryString(int i) { + return String.toDJVM(java.lang.Integer.toBinaryString(i)); + } + + public static String toString(int i) { + return String.toDJVM(java.lang.Integer.toString(i)); + } + + public static String toUnsignedString(int i) { + return String.toDJVM(java.lang.Integer.toUnsignedString(i)); + } + + public static int parseInt(String s, int radix) throws NumberFormatException { + return java.lang.Integer.parseInt(String.fromDJVM(s), radix); + } + + public static int parseInt(String s) throws NumberFormatException { + return java.lang.Integer.parseInt(String.fromDJVM(s)); + } + + public static int parseUnsignedInt(String s, int radix) throws NumberFormatException { + return java.lang.Integer.parseUnsignedInt(String.fromDJVM(s), radix); + } + + public static int parseUnsignedInt(String s) throws NumberFormatException { + return java.lang.Integer.parseUnsignedInt(String.fromDJVM(s)); + } + + public static Integer valueOf(String s, int radix) throws NumberFormatException { + return toDJVM(java.lang.Integer.valueOf(String.fromDJVM(s), radix)); + } + + public static Integer valueOf(String s) throws NumberFormatException { + return toDJVM(java.lang.Integer.valueOf(String.fromDJVM(s))); + } + + public static Integer valueOf(int i) { + return new Integer(i); + } + + public static Integer decode(String nm) throws NumberFormatException { + return new Integer(java.lang.Integer.decode(String.fromDJVM(nm))); + } + + public static int compare(int x, int y) { + return java.lang.Integer.compare(x, y); + } + + public static int compareUnsigned(int x, int y) { + return java.lang.Integer.compareUnsigned(x, y); + } + + public static long toUnsignedLong(int x) { + return java.lang.Integer.toUnsignedLong(x); + } + + public static int divideUnsigned(int dividend, int divisor) { + return java.lang.Integer.divideUnsigned(dividend, divisor); + } + + public static int remainderUnsigned(int dividend, int divisor) { + return java.lang.Integer.remainderUnsigned(dividend, divisor); + } + + public static int highestOneBit(int i) { + return java.lang.Integer.highestOneBit(i); + } + + public static int lowestOneBit(int i) { + return java.lang.Integer.lowestOneBit(i); + } + + public static int numberOfLeadingZeros(int i) { + return java.lang.Integer.numberOfLeadingZeros(i); + } + + public static int numberOfTrailingZeros(int i) { + return java.lang.Integer.numberOfTrailingZeros(i); + } + + public static int bitCount(int i) { + return java.lang.Integer.bitCount(i); + } + + public static int rotateLeft(int i, int distance) { + return java.lang.Integer.rotateLeft(i, distance); + } + + public static int rotateRight(int i, int distance) { + return java.lang.Integer.rotateRight(i, distance); + } + + public static int reverse(int i) { + return java.lang.Integer.reverse(i); + } + + public static int signum(int i) { + return java.lang.Integer.signum(i); + } + + public static int reverseBytes(int i) { + return java.lang.Integer.reverseBytes(i); + } + + public static int sum(int a, int b) { + return java.lang.Integer.sum(a, b); + } + + public static int max(int a, int b) { + return java.lang.Integer.max(a, b); + } + + public static int min(int a, int b) { + return java.lang.Integer.min(a, b); + } + + public static Integer toDJVM(java.lang.Integer i) { + return (i == null) ? null : valueOf(i); + } + + static int stringSize(final int number) { + int i = 0; + while (number > SIZE_TABLE[i]) { + ++i; + } + return i + 1; + } + + static void getChars(final int number, int index, char[] buffer) { + java.lang.String s = java.lang.Integer.toString(number); + int length = s.length(); + + while (length > 0) { + buffer[--index] = s.charAt(--length); + } + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Iterable.java b/djvm/src/main/java/sandbox/java/lang/Iterable.java new file mode 100644 index 0000000000..6032fd97db --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Iterable.java @@ -0,0 +1,15 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +import java.util.Iterator; + +/** + * This is a dummy class that implements just enough of [java.lang.Iterable] + * to allow us to compile [sandbox.java.lang.String]. + */ +public interface Iterable extends java.lang.Iterable { + @Override + @NotNull + Iterator iterator(); +} diff --git a/djvm/src/main/java/sandbox/java/lang/Long.java b/djvm/src/main/java/sandbox/java/lang/Long.java new file mode 100644 index 0000000000..0f07158af1 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Long.java @@ -0,0 +1,239 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Long extends Number implements Comparable { + + public static final long MIN_VALUE = java.lang.Long.MIN_VALUE; + public static final long MAX_VALUE = java.lang.Long.MAX_VALUE; + public static final int BYTES = java.lang.Long.BYTES; + public static final int SIZE = java.lang.Long.SIZE; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Long.TYPE; + + private final long value; + + public Long(long value) { + this.value = value; + } + + public Long(String s) throws NumberFormatException { + this.value = parseLong(s, 10); + } + + @Override + public int hashCode() { + return hashCode(value); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Long) && ((Long) other).longValue() == value; + } + + public static int hashCode(long l) { + return java.lang.Long.hashCode(l); + } + + @Override + public int intValue() { + return (int) value; + } + + @Override + public long longValue() { + return value; + } + + @Override + public short shortValue() { + return (short) value; + } + + @Override + public byte byteValue() { + return (byte) value; + } + + @Override + public float floatValue() { + return (float) value; + } + + @Override + public double doubleValue() { + return (double) value; + } + + @Override + public int compareTo(@NotNull Long other) { + return compare(value, other.value); + } + + public static int compare(long x, long y) { + return java.lang.Long.compare(x, y); + } + + @Override + @NotNull + java.lang.Long fromDJVM() { + return value; + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Long.toString(value); + } + + public static String toString(long l) { + return String.toDJVM(java.lang.Long.toString(l)); + } + + public static String toString(long l, int radix) { + return String.toDJVM(java.lang.Long.toString(l, radix)); + } + + public static String toUnsignedString(long l, int radix) { + return String.toDJVM(java.lang.Long.toUnsignedString(l, radix)); + } + + public static String toUnsignedString(long l) { + return String.toDJVM(java.lang.Long.toUnsignedString(l)); + } + + public static String toHexString(long l) { + return String.toDJVM(java.lang.Long.toHexString(l)); + } + + public static String toOctalString(long l) { + return String.toDJVM(java.lang.Long.toOctalString(l)); + } + + public static String toBinaryString(long l) { + return String.toDJVM(java.lang.Long.toBinaryString(l)); + } + + public static long parseLong(String s, int radix) throws NumberFormatException { + return java.lang.Long.parseLong(String.fromDJVM(s), radix); + } + + public static long parseLong(String s) throws NumberFormatException { + return java.lang.Long.parseLong(String.fromDJVM(s)); + } + + public static long parseUnsignedLong(String s, int radix) throws NumberFormatException { + return java.lang.Long.parseUnsignedLong(String.fromDJVM(s), radix); + } + + public static long parseUnsignedLong(String s) throws NumberFormatException { + return java.lang.Long.parseUnsignedLong(String.fromDJVM(s)); + } + + public static Long valueOf(String s, int radix) throws NumberFormatException { + return toDJVM(java.lang.Long.valueOf(String.fromDJVM(s), radix)); + } + + public static Long valueOf(String s) throws NumberFormatException { + return toDJVM(java.lang.Long.valueOf(String.fromDJVM(s))); + } + + public static Long valueOf(long l) { + return new Long(l); + } + + public static Long decode(String s) throws NumberFormatException { + return toDJVM(java.lang.Long.decode(String.fromDJVM(s))); + } + + public static int compareUnsigned(long x, long y) { + return java.lang.Long.compareUnsigned(x, y); + } + + public static long divideUnsigned(long dividend, long divisor) { + return java.lang.Long.divideUnsigned(dividend, divisor); + } + + public static long remainderUnsigned(long dividend, long divisor) { + return java.lang.Long.remainderUnsigned(dividend, divisor); + } + + public static long highestOneBit(long l) { + return java.lang.Long.highestOneBit(l); + } + + public static long lowestOneBit(long l) { + return java.lang.Long.lowestOneBit(l); + } + + public static int numberOfLeadingZeros(long l) { + return java.lang.Long.numberOfLeadingZeros(l); + } + + public static int numberOfTrailingZeros(long l) { + return java.lang.Long.numberOfTrailingZeros(l); + } + + public static int bitCount(long l) { + return java.lang.Long.bitCount(l); + } + + public static long rotateLeft(long i, int distance) { + return java.lang.Long.rotateLeft(i, distance); + } + + public static long rotateRight(long i, int distance) { + return java.lang.Long.rotateRight(i, distance); + } + + public static long reverse(long l) { + return java.lang.Long.reverse(l); + } + + public static int signum(long l) { + return java.lang.Long.signum(l); + } + + public static long reverseBytes(long l) { + return java.lang.Long.reverseBytes(l); + } + + public static long sum(long a, long b) { + return java.lang.Long.sum(a, b); + } + + public static long max(long a, long b) { + return java.lang.Long.max(a, b); + } + + public static long min(long a, long b) { + return java.lang.Long.min(a, b); + } + + public static Long toDJVM(java.lang.Long l) { + return (l == null) ? null : valueOf(l); + } + + static int stringSize(final long number) { + long l = 10; + int i = 1; + + while ((i < 19) && (number >= l)) { + l *= 10; + ++i; + } + + return i; + } + + static void getChars(final long number, int index, char[] buffer) { + java.lang.String s = java.lang.Long.toString(number); + int length = s.length(); + + while (length > 0) { + buffer[--index] = s.charAt(--length); + } + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Number.java b/djvm/src/main/java/sandbox/java/lang/Number.java new file mode 100644 index 0000000000..89d0a7fd8e --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Number.java @@ -0,0 +1,21 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; +import java.io.Serializable; + +@SuppressWarnings("unused") +public abstract class Number extends Object implements Serializable { + + public abstract double doubleValue(); + public abstract float floatValue(); + public abstract long longValue(); + public abstract int intValue(); + public abstract short shortValue(); + public abstract byte byteValue(); + + @Override + @NotNull + public String toDJVMString() { + return String.toDJVM(toString()); + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Object.java b/djvm/src/main/java/sandbox/java/lang/Object.java new file mode 100644 index 0000000000..4208a52a53 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Object.java @@ -0,0 +1,71 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; +import sandbox.net.corda.djvm.rules.RuleViolationError; + +public class Object { + + @Override + public int hashCode() { + return sandbox.java.lang.System.identityHashCode(this); + } + + @Override + @NotNull + public java.lang.String toString() { + return toDJVMString().toString(); + } + + @NotNull + public String toDJVMString() { + return String.toDJVM("sandbox.java.lang.Object@" + java.lang.Integer.toString(hashCode(), 16)); + } + + @NotNull + java.lang.Object fromDJVM() { + return this; + } + + public static java.lang.Object[] fromDJVM(java.lang.Object[] args) { + if (args == null) { + return null; + } + + java.lang.Object[] unwrapped = (java.lang.Object[]) java.lang.reflect.Array.newInstance( + fromDJVM(args.getClass().getComponentType()), args.length + ); + int i = 0; + for (java.lang.Object arg : args) { + unwrapped[i] = unwrap(arg); + ++i; + } + return unwrapped; + } + + private static java.lang.Object unwrap(java.lang.Object arg) { + if (arg instanceof Object) { + return ((Object) arg).fromDJVM(); + } else if (Object[].class.isAssignableFrom(arg.getClass())) { + return fromDJVM((Object[]) arg); + } else { + return arg; + } + } + + private static Class fromDJVM(Class type) { + try { + java.lang.String name = type.getName(); + return Class.forName(name.startsWith("sandbox.") ? name.substring(8) : name); + } catch (ClassNotFoundException e) { + throw new RuleViolationError(e.getMessage()); + } + } + + static java.util.Locale fromDJVM(sandbox.java.util.Locale locale) { + return java.util.Locale.forLanguageTag(locale.toLanguageTag().fromDJVM()); + } + + static java.nio.charset.Charset fromDJVM(sandbox.java.nio.charset.Charset charset) { + return java.nio.charset.Charset.forName(charset.name().fromDJVM()); + } +} diff --git a/djvm/src/main/java/sandbox/java/lang/Runtime.java b/djvm/src/main/java/sandbox/java/lang/Runtime.java new file mode 100644 index 0000000000..830233072b --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Runtime.java @@ -0,0 +1,27 @@ +package sandbox.java.lang; + +@SuppressWarnings("unused") +public final class Runtime extends Object { + private static final Runtime RUNTIME = new Runtime(); + + private Runtime() {} + + public static Runtime getRuntime() { + return RUNTIME; + } + + /** + * Everything inside the sandbox is single-threaded. + * @return 1 + */ + public int availableProcessors() { + return 1; + } + + public void loadLibrary(String libraryName) {} + + public void load(String fileName) {} + + public void runFinalization() {} + public void gc() {} +} diff --git a/djvm/src/main/java/sandbox/java/lang/Short.java b/djvm/src/main/java/sandbox/java/lang/Short.java new file mode 100644 index 0000000000..a0e1cbfd39 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/Short.java @@ -0,0 +1,128 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; + +@SuppressWarnings({"unused", "WeakerAccess"}) +public final class Short extends Number implements Comparable { + public static final short MIN_VALUE = java.lang.Short.MIN_VALUE; + public static final short MAX_VALUE = java.lang.Short.MAX_VALUE; + public static final int BYTES = java.lang.Short.BYTES; + public static final int SIZE = java.lang.Short.SIZE; + + @SuppressWarnings("unchecked") + public static final Class TYPE = (Class) java.lang.Short.TYPE; + + private final short value; + + public Short(short value) { + this.value = value; + } + + public Short(String s) throws NumberFormatException { + this.value = parseShort(s); + } + + @Override + public byte byteValue() { + return (byte)value; + } + + @Override + public short shortValue() { + return value; + } + + @Override + public int intValue() { + return value; + } + + @Override + public long longValue() { + return (long)value; + } + + @Override + public float floatValue() { + return (float)value; + } + + @Override + public double doubleValue() { + return (double)value; + } + + @Override + @NotNull + public java.lang.String toString() { + return java.lang.Integer.toString(value); + } + + @Override + @NotNull + java.lang.Short fromDJVM() { + return value; + } + + @Override + public int hashCode() { + return hashCode(value); + } + + public static int hashCode(short value) { + return java.lang.Short.hashCode(value); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof Short) && ((Short) other).value == value; + } + + public int compareTo(@NotNull Short other) { + return compare(this.value, other.value); + } + + public static int compare(short x, short y) { + return java.lang.Short.compare(x, y); + } + + public static short reverseBytes(short value) { + return java.lang.Short.reverseBytes(value); + } + + public static int toUnsignedInt(short x) { + return java.lang.Short.toUnsignedInt(x); + } + + public static long toUnsignedLong(short x) { + return java.lang.Short.toUnsignedLong(x); + } + + public static short parseShort(String s, int radix) throws NumberFormatException { + return java.lang.Short.parseShort(String.fromDJVM(s), radix); + } + + public static short parseShort(String s) throws NumberFormatException { + return java.lang.Short.parseShort(String.fromDJVM(s)); + } + + public static Short valueOf(String s, int radix) throws NumberFormatException { + return toDJVM(java.lang.Short.valueOf(String.fromDJVM(s), radix)); + } + + public static Short valueOf(String s) throws NumberFormatException { + return toDJVM(java.lang.Short.valueOf(String.fromDJVM(s))); + } + + public static Short valueOf(short s) { + return new Short(s); + } + + public static Short decode(String nm) throws NumberFormatException { + return toDJVM(java.lang.Short.decode(String.fromDJVM(nm))); + } + + public static Short toDJVM(java.lang.Short i) { + return (i == null) ? null : valueOf(i); + } +} \ No newline at end of file diff --git a/djvm/src/main/java/sandbox/java/lang/String.java b/djvm/src/main/java/sandbox/java/lang/String.java new file mode 100644 index 0000000000..4cce494d30 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/String.java @@ -0,0 +1,398 @@ +package sandbox.java.lang; + +import org.jetbrains.annotations.NotNull; +import sandbox.java.nio.charset.Charset; +import sandbox.java.util.Comparator; +import sandbox.java.util.Locale; + +import java.io.Serializable; +import java.io.UnsupportedEncodingException; + +@SuppressWarnings("unused") +public final class String extends Object implements Comparable, CharSequence, Serializable { + public static final Comparator CASE_INSENSITIVE_ORDER = new CaseInsensitiveComparator(); + + private static class CaseInsensitiveComparator extends Object implements Comparator, Serializable { + @Override + public int compare(String s1, String s2) { + return java.lang.String.CASE_INSENSITIVE_ORDER.compare(String.fromDJVM(s1), String.fromDJVM(s2)); + } + } + + private static final String TRUE = new String("true"); + private static final String FALSE = new String("false"); + + private final java.lang.String value; + + public String() { + this.value = ""; + } + + public String(java.lang.String value) { + this.value = value; + } + + public String(char value[]) { + this.value = new java.lang.String(value); + } + + public String(char value[], int offset, int count) { + this.value = new java.lang.String(value, offset, count); + } + + public String(int[] codePoints, int offset, int count) { + this.value = new java.lang.String(codePoints, offset, count); + } + + @Deprecated + public String(byte ascii[], int hibyte, int offset, int count) { + this.value = new java.lang.String(ascii, hibyte, offset, count); + } + + @Deprecated + public String(byte ascii[], int hibyte) { + this.value = new java.lang.String(ascii, hibyte); + } + + public String(byte bytes[], int offset, int length, String charsetName) + throws UnsupportedEncodingException { + this.value = new java.lang.String(bytes, offset, length, fromDJVM(charsetName)); + } + + public String(byte bytes[], int offset, int length, Charset charset) { + this.value = new java.lang.String(bytes, offset, length, fromDJVM(charset)); + } + + public String(byte bytes[], String charsetName) + throws UnsupportedEncodingException { + this.value = new java.lang.String(bytes, fromDJVM(charsetName)); + } + + public String(byte bytes[], Charset charset) { + this.value = new java.lang.String(bytes, fromDJVM(charset)); + } + + public String(byte bytes[], int offset, int length) { + this.value = new java.lang.String(bytes, offset, length); + } + + public String(byte bytes[]) { + this.value = new java.lang.String(bytes); + } + + public String(StringBuffer buffer) { + this.value = buffer.toString(); + } + + public String(StringBuilder builder) { + this.value = builder.toString(); + } + + @Override + public char charAt(int index) { + return value.charAt(index); + } + + @Override + public int length() { + return value.length(); + } + + public boolean isEmpty() { + return value.isEmpty(); + } + + public int codePointAt(int index) { + return value.codePointAt(index); + } + + public int codePointBefore(int index) { + return value.codePointBefore(index); + } + + public int codePointCount(int beginIndex, int endIndex) { + return value.codePointCount(beginIndex, endIndex); + } + + public int offsetByCodePoints(int index, int codePointOffset) { + return value.offsetByCodePoints(index, codePointOffset); + } + + public void getChars(int srcBegin, int srcEnd, char dst[], int dstBegin) { + value.getChars(srcBegin, srcEnd, dst, dstBegin); + } + + @Deprecated + public void getBytes(int srcBegin, int srcEnd, byte dst[], int dstBegin) { + value.getBytes(srcBegin, srcEnd, dst, dstBegin); + } + + public byte[] getBytes(String charsetName) throws UnsupportedEncodingException { + return value.getBytes(fromDJVM(charsetName)); + } + + public byte[] getBytes(Charset charset) { + return value.getBytes(fromDJVM(charset)); + } + + public byte[] getBytes() { + return value.getBytes(); + } + + @Override + public boolean equals(java.lang.Object other) { + return (other instanceof String) && ((String) other).value.equals(value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + @NotNull + public java.lang.String toString() { + return value; + } + + @Override + @NotNull + public String toDJVMString() { + return this; + } + + @Override + @NotNull + java.lang.String fromDJVM() { + return value; + } + + public boolean contentEquals(StringBuffer sb) { + return value.contentEquals((CharSequence) sb); + } + + public boolean contentEquals(CharSequence cs) { + return value.contentEquals(cs); + } + + public boolean equalsIgnoreCase(String anotherString) { + return value.equalsIgnoreCase(fromDJVM(anotherString)); + } + + @Override + public CharSequence subSequence(int start, int end) { + return toDJVM((java.lang.String) value.subSequence(start, end)); + } + + @Override + public int compareTo(@NotNull String other) { + return value.compareTo(other.toString()); + } + + public int compareToIgnoreCase(String str) { + return value.compareToIgnoreCase(fromDJVM(str)); + } + + public boolean regionMatches(int toffset, String other, int ooffset, int len) { + return value.regionMatches(toffset, fromDJVM(other), ooffset, len); + } + + public boolean regionMatches(boolean ignoreCase, int toffset, + String other, int ooffset, int len) { + return value.regionMatches(ignoreCase, toffset, fromDJVM(other), ooffset, len); + } + + public boolean startsWith(String prefix, int toffset) { + return value.startsWith(fromDJVM(prefix), toffset); + } + + public boolean startsWith(String prefix) { + return value.startsWith(fromDJVM(prefix)); + } + + public boolean endsWith(String suffix) { + return value.endsWith(fromDJVM(suffix)); + } + + public int indexOf(int ch) { + return value.indexOf(ch); + } + + public int indexOf(int ch, int fromIndex) { + return value.indexOf(ch, fromIndex); + } + + public int lastIndexOf(int ch) { + return value.lastIndexOf(ch); + } + + public int lastIndexOf(int ch, int fromIndex) { + return value.lastIndexOf(ch, fromIndex); + } + + public int indexOf(String str) { + return value.indexOf(fromDJVM(str)); + } + + public int indexOf(String str, int fromIndex) { + return value.indexOf(fromDJVM(str), fromIndex); + } + + public int lastIndexOf(String str) { + return value.lastIndexOf(fromDJVM(str)); + } + + public int lastIndexOf(String str, int fromIndex) { + return value.lastIndexOf(fromDJVM(str), fromIndex); + } + + public String substring(int beginIndex) { + return toDJVM(value.substring(beginIndex)); + } + + public String substring(int beginIndex, int endIndex) { + return toDJVM(value.substring(beginIndex, endIndex)); + } + + public String concat(String str) { + return toDJVM(value.concat(fromDJVM(str))); + } + + public String replace(char oldChar, char newChar) { + return toDJVM(value.replace(oldChar, newChar)); + } + + public boolean matches(String regex) { + return value.matches(fromDJVM(regex)); + } + + public boolean contains(CharSequence s) { + return value.contains(s); + } + + public String replaceFirst(String regex, String replacement) { + return toDJVM(value.replaceFirst(fromDJVM(regex), fromDJVM(replacement))); + } + + public String replaceAll(String regex, String replacement) { + return toDJVM(value.replaceAll(fromDJVM(regex), fromDJVM(replacement))); + } + + public String replace(CharSequence target, CharSequence replacement) { + return toDJVM(value.replace(target, replacement)); + } + + public String[] split(String regex, int limit) { + return toDJVM(value.split(fromDJVM(regex), limit)); + } + + public String[] split(String regex) { + return toDJVM(value.split(fromDJVM(regex))); + } + + public String toLowerCase(Locale locale) { + return toDJVM(value.toLowerCase(fromDJVM(locale))); + } + + public String toLowerCase() { + return toDJVM(value.toLowerCase()); + } + + public String toUpperCase(Locale locale) { + return toDJVM(value.toUpperCase(fromDJVM(locale))); + } + + public String toUpperCase() { + return toDJVM(value.toUpperCase()); + } + + public String trim() { + return toDJVM(value.trim()); + } + + public char[] toCharArray() { + return value.toCharArray(); + } + + public static String format(String format, java.lang.Object... args) { + return toDJVM(java.lang.String.format(fromDJVM(format), fromDJVM(args))); + } + + public static String format(Locale locale, String format, java.lang.Object... args) { + return toDJVM(java.lang.String.format(fromDJVM(locale), fromDJVM(format), fromDJVM(args))); + } + + public static String join(CharSequence delimiter, CharSequence... elements) { + return toDJVM(java.lang.String.join(delimiter, elements)); + } + + public static String join(CharSequence delimiter, + Iterable elements) { + return toDJVM(java.lang.String.join(delimiter, elements)); + } + + public static String valueOf(java.lang.Object obj) { + return (obj instanceof Object) ? ((Object) obj).toDJVMString() : toDJVM(java.lang.String.valueOf(obj)); + } + + public static String valueOf(char data[]) { + return toDJVM(java.lang.String.valueOf(data)); + } + + public static String valueOf(char data[], int offset, int count) { + return toDJVM(java.lang.String.valueOf(data, offset, count)); + } + + public static String copyValueOf(char data[], int offset, int count) { + return toDJVM(java.lang.String.copyValueOf(data, offset, count)); + } + + public static String copyValueOf(char data[]) { + return toDJVM(java.lang.String.copyValueOf(data)); + } + + public static String valueOf(boolean b) { + return b ? TRUE : FALSE; + } + + public static String valueOf(char c) { + return toDJVM(java.lang.String.valueOf(c)); + } + + public static String valueOf(int i) { + return toDJVM(java.lang.String.valueOf(i)); + } + + public static String valueOf(long l) { + return toDJVM(java.lang.String.valueOf(l)); + } + + public static String valueOf(float f) { + return toDJVM(java.lang.String.valueOf(f)); + } + + public static String valueOf(double d) { + return toDJVM(java.lang.String.valueOf(d)); + } + + static String[] toDJVM(java.lang.String[] value) { + if (value == null) { + return null; + } + String[] result = new String[value.length]; + int i = 0; + for (java.lang.String v : value) { + result[i] = toDJVM(v); + ++i; + } + return result; + } + + public static String toDJVM(java.lang.String value) { + return (value == null) ? null : new String(value); + } + + public static java.lang.String fromDJVM(String value) { + return (value == null) ? null : value.fromDJVM(); + } +} \ No newline at end of file diff --git a/djvm/src/main/java/sandbox/java/lang/StringBuffer.java b/djvm/src/main/java/sandbox/java/lang/StringBuffer.java new file mode 100644 index 0000000000..e9cbcad328 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/StringBuffer.java @@ -0,0 +1,20 @@ +package sandbox.java.lang; + +import java.io.Serializable; + +/** + * This is a dummy class that implements just enough of [java.lang.StringBuffer] + * to allow us to compile [sandbox.java.lang.String]. + */ +public abstract class StringBuffer extends Object implements CharSequence, Appendable, Serializable { + + @Override + public abstract StringBuffer append(CharSequence seq); + + @Override + public abstract StringBuffer append(CharSequence seq, int start, int end); + + @Override + public abstract StringBuffer append(char c); + +} diff --git a/djvm/src/main/java/sandbox/java/lang/StringBuilder.java b/djvm/src/main/java/sandbox/java/lang/StringBuilder.java new file mode 100644 index 0000000000..ed80b2e508 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/StringBuilder.java @@ -0,0 +1,20 @@ +package sandbox.java.lang; + +import java.io.Serializable; + +/** + * This is a dummy class that implements just enough of [java.lang.StringBuilder] + * to allow us to compile [sandbox.java.lang.String]. + */ +public abstract class StringBuilder extends Object implements Appendable, CharSequence, Serializable { + + @Override + public abstract StringBuilder append(CharSequence seq); + + @Override + public abstract StringBuilder append(CharSequence seq, int start, int end); + + @Override + public abstract StringBuilder append(char c); + +} diff --git a/djvm/src/main/java/sandbox/java/lang/System.java b/djvm/src/main/java/sandbox/java/lang/System.java new file mode 100644 index 0000000000..95525d0b50 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/System.java @@ -0,0 +1,28 @@ +package sandbox.java.lang; + +@SuppressWarnings({"WeakerAccess", "unused"}) +public final class System extends Object { + + private System() {} + + /* + * This class is duplicated into every sandbox, where everything is single-threaded. + */ + private static final java.util.Map objectHashCodes = new java.util.LinkedHashMap<>(); + private static int objectCounter = 0; + + public static int identityHashCode(java.lang.Object obj) { + int nativeHashCode = java.lang.System.identityHashCode(obj); + // TODO Instead of using a magic offset below, one could take in a per-context seed + return objectHashCodes.computeIfAbsent(nativeHashCode, i -> ++objectCounter + 0xfed_c0de); + } + + public static final String lineSeparator = String.toDJVM("\n"); + + public static void arraycopy(java.lang.Object src, int srcPos, java.lang.Object dest, int destPos, int length) { + java.lang.System.arraycopy(src, srcPos, dest, destPos, length); + } + + public static void runFinalization() {} + public static void gc() {} +} diff --git a/djvm/src/main/java/sandbox/java/lang/ThreadLocal.java b/djvm/src/main/java/sandbox/java/lang/ThreadLocal.java new file mode 100644 index 0000000000..f416d3db16 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/lang/ThreadLocal.java @@ -0,0 +1,59 @@ +package sandbox.java.lang; + +import sandbox.java.util.function.Supplier; + +/** + * Everything inside the sandbox is single-threaded, so this + * implementation of ThreadLocal is sufficient. + * @param + */ +@SuppressWarnings({"unused", "WeakerAccess"}) +public class ThreadLocal extends Object { + + private T value; + private boolean isSet; + + public ThreadLocal() { + } + + protected T initialValue() { + return null; + } + + public T get() { + if (!isSet) { + set(initialValue()); + } + return value; + } + + public void set(T value) { + this.value = value; + this.isSet = true; + } + + public void remove() { + value = null; + isSet = false; + } + + public static ThreadLocal withInitial(Supplier supplier) { + return new SuppliedThreadLocal<>(supplier); + } + + // Stub class for compiling ThreadLocal. The sandbox will import the + // actual SuppliedThreadLocal class at run-time. Having said that, we + // still need a working implementation here for the sake of our tests. + static final class SuppliedThreadLocal extends ThreadLocal { + private final Supplier supplier; + + SuppliedThreadLocal(Supplier supplier) { + this.supplier = supplier; + } + + @Override + protected T initialValue() { + return supplier.get(); + } + } +} diff --git a/djvm/src/main/java/sandbox/java/nio/charset/Charset.java b/djvm/src/main/java/sandbox/java/nio/charset/Charset.java new file mode 100644 index 0000000000..371a21404a --- /dev/null +++ b/djvm/src/main/java/sandbox/java/nio/charset/Charset.java @@ -0,0 +1,18 @@ +package sandbox.java.nio.charset; + +/** + * This is a dummy class that implements just enough of [java.nio.charset.Charset] + * to allow us to compile [sandbox.java.lang.String]. + */ +@SuppressWarnings("unused") +public abstract class Charset extends sandbox.java.lang.Object { + private final sandbox.java.lang.String canonicalName; + + protected Charset(sandbox.java.lang.String canonicalName, sandbox.java.lang.String[] aliases) { + this.canonicalName = canonicalName; + } + + public final sandbox.java.lang.String name() { + return canonicalName; + } +} diff --git a/djvm/src/main/java/sandbox/java/util/Comparator.java b/djvm/src/main/java/sandbox/java/util/Comparator.java new file mode 100644 index 0000000000..20679dee59 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/util/Comparator.java @@ -0,0 +1,9 @@ +package sandbox.java.util; + +/** + * This is a dummy class that implements just enough of [java.util.Comparator] + * to allow us to compile [sandbox.java.lang.String]. + */ +@FunctionalInterface +public interface Comparator extends java.util.Comparator { +} diff --git a/djvm/src/main/java/sandbox/java/util/LinkedHashMap.java b/djvm/src/main/java/sandbox/java/util/LinkedHashMap.java new file mode 100644 index 0000000000..37d8c56210 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/util/LinkedHashMap.java @@ -0,0 +1,13 @@ +package sandbox.java.util; + +/** + * This is a dummy class to bootstrap us into the sandbox. + */ +public class LinkedHashMap extends java.util.LinkedHashMap implements Map { + public LinkedHashMap(int initialSize) { + super(initialSize); + } + + public LinkedHashMap() { + } +} diff --git a/djvm/src/main/java/sandbox/java/util/Locale.java b/djvm/src/main/java/sandbox/java/util/Locale.java new file mode 100644 index 0000000000..3ceaea9382 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/util/Locale.java @@ -0,0 +1,9 @@ +package sandbox.java.util; + +/** + * This is a dummy class that implements just enough of [java.util.Locale] + * to allow us to compile [sandbox.java.lang.String]. + */ +public abstract class Locale extends sandbox.java.lang.Object { + public abstract sandbox.java.lang.String toLanguageTag(); +} diff --git a/djvm/src/main/java/sandbox/java/util/Map.java b/djvm/src/main/java/sandbox/java/util/Map.java new file mode 100644 index 0000000000..576e462583 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/util/Map.java @@ -0,0 +1,7 @@ +package sandbox.java.util; + +/** + * This is a dummy class to bootstrap us into the sandbox. + */ +public interface Map extends java.util.Map { +} diff --git a/djvm/src/main/java/sandbox/java/util/function/Function.java b/djvm/src/main/java/sandbox/java/util/function/Function.java new file mode 100644 index 0000000000..5cd806a01e --- /dev/null +++ b/djvm/src/main/java/sandbox/java/util/function/Function.java @@ -0,0 +1,10 @@ +package sandbox.java.util.function; + +/** + * This is a dummy class that implements just enough of [java.util.function.Function] + * to allow us to compile [sandbox.Task]. + */ +@FunctionalInterface +public interface Function { + R apply(T item); +} diff --git a/djvm/src/main/java/sandbox/java/util/function/Supplier.java b/djvm/src/main/java/sandbox/java/util/function/Supplier.java new file mode 100644 index 0000000000..31f236bae6 --- /dev/null +++ b/djvm/src/main/java/sandbox/java/util/function/Supplier.java @@ -0,0 +1,10 @@ +package sandbox.java.util.function; + +/** + * This is a dummy class that implements just enough of [java.util.function.Supplier] + * to allow us to compile [sandbox.java.lang.ThreadLocal]. + */ +@FunctionalInterface +public interface Supplier { + T get(); +} diff --git a/djvm/src/main/java/sandbox/sun/misc/JavaLangAccess.java b/djvm/src/main/java/sandbox/sun/misc/JavaLangAccess.java new file mode 100644 index 0000000000..189a7f9711 --- /dev/null +++ b/djvm/src/main/java/sandbox/sun/misc/JavaLangAccess.java @@ -0,0 +1,10 @@ +package sandbox.sun.misc; + +import sandbox.java.lang.Enum; + +@SuppressWarnings("unused") +public interface JavaLangAccess { + + > E[] getEnumConstantsShared(Class enumClass); + +} diff --git a/djvm/src/main/java/sandbox/sun/misc/SharedSecrets.java b/djvm/src/main/java/sandbox/sun/misc/SharedSecrets.java new file mode 100644 index 0000000000..a03f7689c1 --- /dev/null +++ b/djvm/src/main/java/sandbox/sun/misc/SharedSecrets.java @@ -0,0 +1,20 @@ +package sandbox.sun.misc; + +import sandbox.java.lang.Enum; + +@SuppressWarnings("unused") +public class SharedSecrets extends sandbox.java.lang.Object { + private static final JavaLangAccess javaLangAccess = new JavaLangAccessImpl(); + + private static class JavaLangAccessImpl implements JavaLangAccess { + @SuppressWarnings("unchecked") + @Override + public > E[] getEnumConstantsShared(Class enumClass) { + return (E[]) sandbox.java.lang.DJVM.getEnumConstantsShared(enumClass); + } + } + + public static JavaLangAccess getJavaLangAccess() { + return javaLangAccess; + } +} diff --git a/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt b/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt index 2a1e7d63cf..f8d87fd1ea 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt @@ -1,12 +1,17 @@ package net.corda.djvm.analysis +import net.corda.djvm.code.EmitterModule import net.corda.djvm.code.ruleViolationError import net.corda.djvm.code.thresholdViolationError import net.corda.djvm.messages.Severity import net.corda.djvm.references.ClassModule +import net.corda.djvm.references.Member import net.corda.djvm.references.MemberModule +import net.corda.djvm.references.MethodBody import net.corda.djvm.source.BootstrapClassLoader import net.corda.djvm.source.SourceClassLoader +import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.Type import sandbox.net.corda.djvm.costing.RuntimeCostAccounter import java.io.Closeable import java.io.IOException @@ -41,19 +46,22 @@ class AnalysisConfiguration( /** * Classes that have already been declared in the sandbox namespace and that should be made - * available inside the sandboxed environment. + * available inside the sandboxed environment. These classes belong to the application + * classloader and so are shared across all sandboxes. */ - val pinnedClasses: Set = setOf( - SANDBOXED_OBJECT, - RuntimeCostAccounter.TYPE_NAME, - ruleViolationError, - thresholdViolationError - ) + additionalPinnedClasses + val pinnedClasses: Set = MANDATORY_PINNED_CLASSES + additionalPinnedClasses + + /** + * These interfaces are modified as they are mapped into the sandbox by + * having their unsandboxed version "stitched in" as a super-interface. + * And in some cases, we need to add some synthetic bridge methods as well. + */ + val stitchedInterfaces: Map> get() = STITCHED_INTERFACES /** * Functionality used to resolve the qualified name and relevant information about a class. */ - val classResolver: ClassResolver = ClassResolver(pinnedClasses, whitelist, SANDBOX_PREFIX) + val classResolver: ClassResolver = ClassResolver(pinnedClasses, TEMPLATE_CLASSES, whitelist, SANDBOX_PREFIX) private val bootstrapClassLoader = bootstrapJar?.let { BootstrapClassLoader(it, classResolver) } val supportingClassLoader = SourceClassLoader(classPath, classResolver, bootstrapClassLoader) @@ -65,13 +73,114 @@ class AnalysisConfiguration( } } + fun isTemplateClass(className: String): Boolean = className in TEMPLATE_CLASSES + fun isPinnedClass(className: String): Boolean = className in pinnedClasses + companion object { /** * The package name prefix to use for classes loaded into a sandbox. */ private const val SANDBOX_PREFIX: String = "sandbox/" - private const val SANDBOXED_OBJECT = SANDBOX_PREFIX + "java/lang/Object" + /** + * These class must belong to the application class loader. + * They should already exist within the sandbox namespace. + */ + private val MANDATORY_PINNED_CLASSES: Set = setOf( + RuntimeCostAccounter.TYPE_NAME, + ruleViolationError, + thresholdViolationError + ) + + /** + * These classes will be duplicated into every sandbox's + * classloader. + */ + private val TEMPLATE_CLASSES: Set = setOf( + java.lang.Boolean::class.java, + java.lang.Byte::class.java, + java.lang.Character::class.java, + java.lang.Double::class.java, + java.lang.Float::class.java, + java.lang.Integer::class.java, + java.lang.Long::class.java, + java.lang.Number::class.java, + java.lang.Runtime::class.java, + java.lang.Short::class.java, + java.lang.String::class.java, + java.lang.String.CASE_INSENSITIVE_ORDER::class.java, + java.lang.System::class.java, + java.lang.ThreadLocal::class.java, + kotlin.Any::class.java, + sun.misc.JavaLangAccess::class.java, + sun.misc.SharedSecrets::class.java + ).sandboxed() + setOf( + "sandbox/Task", + "sandbox/java/lang/DJVM", + "sandbox/sun/misc/SharedSecrets\$1", + "sandbox/sun/misc/SharedSecrets\$JavaLangAccessImpl" + ) + + /** + * These interfaces will be modified as follows when + * added to the sandbox: + * + * interface sandbox.A extends A + */ + private val STITCHED_INTERFACES: Map> = mapOf( + sandboxed(CharSequence::class.java) to listOf( + object : MethodBuilder( + access = ACC_PUBLIC or ACC_SYNTHETIC or ACC_BRIDGE, + className = "sandbox/java/lang/CharSequence", + memberName = "subSequence", + descriptor = "(II)Ljava/lang/CharSequence;" + ) { + override fun writeBody(emitter: EmitterModule) = with(emitter) { + pushObject(0) + pushInteger(1) + pushInteger(2) + invokeInterface(className, memberName, "(II)L$className;") + returnObject() + } + }.withBody() + .build(), + MethodBuilder( + access = ACC_PUBLIC or ACC_ABSTRACT, + className = "sandbox/java/lang/CharSequence", + memberName = "toString", + descriptor = "()Ljava/lang/String;" + ).build() + ), + sandboxed(Comparable::class.java) to emptyList(), + sandboxed(Comparator::class.java) to emptyList(), + sandboxed(Iterable::class.java) to emptyList() + ) + + private fun sandboxed(clazz: Class<*>) = SANDBOX_PREFIX + Type.getInternalName(clazz) + private fun Set>.sandboxed(): Set = map(Companion::sandboxed).toSet() } + private open class MethodBuilder( + protected val access: Int, + protected val className: String, + protected val memberName: String, + protected val descriptor: String) { + private val bodies = mutableListOf() + + protected open fun writeBody(emitter: EmitterModule) {} + + fun withBody(): MethodBuilder { + bodies.add(::writeBody) + return this + } + + fun build() = Member( + access = access, + className = className, + memberName = memberName, + signature = descriptor, + genericsDetails = "", + body = bodies + ) + } } diff --git a/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassAndMemberVisitor.kt b/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassAndMemberVisitor.kt index d0d9cb4e8c..8bfb997ae7 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassAndMemberVisitor.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassAndMemberVisitor.kt @@ -85,8 +85,12 @@ open class ClassAndMemberVisitor( /** * Process class after it has been fully traversed and analyzed. + * The [classVisitor] has finished visiting all of the class's + * existing elements (i.e. methods, fields, inner classes etc) + * and is about to complete. However, it can still add new + * elements to the class, if required. */ - open fun visitClassEnd(clazz: ClassRepresentation) {} + open fun visitClassEnd(classVisitor: ClassVisitor, clazz: ClassRepresentation) {} /** * Extract the meta-data indicating the source file of the traversed class (i.e., where it is compiled from). @@ -136,7 +140,7 @@ open class ClassAndMemberVisitor( */ protected fun shouldBeProcessed(className: String): Boolean { return !configuration.whitelist.inNamespace(className) && - className !in configuration.pinnedClasses + !configuration.isPinnedClass(className) } /** @@ -241,7 +245,7 @@ open class ClassAndMemberVisitor( .getClassReferencesFromClass(currentClass!!, configuration.analyzeAnnotations) .forEach(::recordTypeReference) captureExceptions { - visitClassEnd(currentClass!!) + visitClassEnd(this, currentClass!!) } super.visitEnd() } @@ -385,7 +389,9 @@ open class ClassAndMemberVisitor( */ override fun visitCode() { tryReplaceMethodBody() - super.visitCode() + visit(MethodEntry(method)) { + super.visitCode() + } } /** @@ -494,6 +500,15 @@ open class ClassAndMemberVisitor( } } + /** + * Transform values loaded from the constants pool. + */ + override fun visitLdcInsn(value: Any) { + visit(ConstantInstruction(value), defaultFirst = true) { + super.visitLdcInsn(value) + } + } + /** * Finish visiting this method, writing any new method body byte-code * if we haven't written it already. This would (presumably) only happen diff --git a/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassResolver.kt b/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassResolver.kt index b1aa3ae541..a05b4ec7ec 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassResolver.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/analysis/ClassResolver.kt @@ -26,6 +26,7 @@ import net.corda.djvm.code.asResourcePath */ class ClassResolver( private val pinnedClasses: Set, + private val templateClasses: Set, private val whitelist: Whitelist, private val sandboxPrefix: String ) { @@ -83,7 +84,7 @@ class ClassResolver( * Reverse the resolution of a class name. */ fun reverse(resolvedClassName: String): String { - if (resolvedClassName in pinnedClasses) { + if (resolvedClassName in pinnedClasses || resolvedClassName in templateClasses) { return resolvedClassName } if (resolvedClassName.startsWith(sandboxPrefix)) { @@ -103,10 +104,10 @@ class ClassResolver( } /** - * Resolve class name from a fully qualified name. + * Resolve sandboxed class name from a fully qualified name. */ private fun resolveName(name: String): String { - return if (isPinnedOrWhitelistedClass(name)) { + return if (isPinnedOrWhitelistedClass(name) || name in templateClasses) { name } else { "$sandboxPrefix$name" @@ -122,10 +123,10 @@ class ClassResolver( sandboxRegex.matches(name) } - private val sandboxRegex = "^$sandboxPrefix.*$".toRegex() + private val sandboxRegex = "^$sandboxPrefix.*\$".toRegex() companion object { - private val complexArrayTypeRegex = "^(\\[+)L(.*);$".toRegex() + private val complexArrayTypeRegex = "^(\\[+)L(.*);\$".toRegex() } } \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/analysis/Whitelist.kt b/djvm/src/main/kotlin/net/corda/djvm/analysis/Whitelist.kt index 3cbbfe8223..c19cc8111e 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/analysis/Whitelist.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/analysis/Whitelist.kt @@ -89,36 +89,25 @@ open class Whitelist private constructor( * Enumerate all the entries of the whitelist. */ val items: Set - get() = textEntries + entries.map { it.pattern } + get() = textEntries + entries.map(Regex::pattern) companion object { private val everythingRegex = setOf(".*".toRegex()) private val minimumSet = setOf( - "^java/lang/Boolean(\\..*)?$".toRegex(), - "^java/lang/Byte(\\..*)?$".toRegex(), - "^java/lang/Character(\\..*)?$".toRegex(), - "^java/lang/Class(\\..*)?$".toRegex(), - "^java/lang/ClassLoader(\\..*)?$".toRegex(), - "^java/lang/Cloneable(\\..*)?$".toRegex(), - "^java/lang/Comparable(\\..*)?$".toRegex(), - "^java/lang/Double(\\..*)?$".toRegex(), - "^java/lang/Enum(\\..*)?$".toRegex(), - "^java/lang/Float(\\..*)?$".toRegex(), - "^java/lang/Integer(\\..*)?$".toRegex(), - "^java/lang/Iterable(\\..*)?$".toRegex(), - "^java/lang/Long(\\..*)?$".toRegex(), - "^java/lang/Number(\\..*)?$".toRegex(), - "^java/lang/Object(\\..*)?$".toRegex(), - "^java/lang/Override(\\..*)?$".toRegex(), - "^java/lang/Short(\\..*)?$".toRegex(), - "^java/lang/String(\\..*)?$".toRegex(), - "^java/lang/ThreadDeath(\\..*)?$".toRegex(), - "^java/lang/Throwable(\\..*)?$".toRegex(), - "^java/lang/Void(\\..*)?$".toRegex(), - "^java/lang/.*Error(\\..*)?$".toRegex(), - "^java/lang/.*Exception(\\..*)?$".toRegex(), - "^java/lang/reflect/Array(\\..*)?$".toRegex() + "^java/lang/Class(\\..*)?\$".toRegex(), + "^java/lang/ClassLoader(\\..*)?\$".toRegex(), + "^java/lang/Cloneable(\\..*)?\$".toRegex(), + "^java/lang/Object(\\..*)?\$".toRegex(), + "^java/lang/Override(\\..*)?\$".toRegex(), + // TODO: sandbox exception handling! + "^java/lang/StackTraceElement\$".toRegex(), + "^java/lang/Throwable\$".toRegex(), + "^java/lang/Void\$".toRegex(), + "^java/lang/invoke/LambdaMetafactory\$".toRegex(), + "^java/lang/invoke/MethodHandles(\\\$.*)?\$".toRegex(), + "^java/lang/reflect/Array(\\..*)?\$".toRegex(), + "^java/io/Serializable\$".toRegex() ) /** diff --git a/djvm/src/main/kotlin/net/corda/djvm/code/ClassMutator.kt b/djvm/src/main/kotlin/net/corda/djvm/code/ClassMutator.kt index 777e69f9fe..3c800d9859 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/code/ClassMutator.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/code/ClassMutator.kt @@ -2,26 +2,48 @@ package net.corda.djvm.code import net.corda.djvm.analysis.AnalysisConfiguration import net.corda.djvm.analysis.ClassAndMemberVisitor +import net.corda.djvm.code.instructions.MethodEntry import net.corda.djvm.references.ClassRepresentation import net.corda.djvm.references.Member +import net.corda.djvm.references.MethodBody import net.corda.djvm.utilities.Processor import net.corda.djvm.utilities.loggerFor import org.objectweb.asm.ClassVisitor +import org.objectweb.asm.Opcodes.* /** * Helper class for applying a set of definition providers and emitters to a class or set of classes. * * @param classVisitor Class visitor to use when traversing the structure of classes. + * @property configuration The configuration to use for class analysis. * @property definitionProviders A set of providers used to update the name or meta-data of classes and members. - * @property emitters A set of code emitters used to modify and instrument method bodies. + * @param emitters A set of code emitters used to modify and instrument method bodies. */ class ClassMutator( classVisitor: ClassVisitor, private val configuration: AnalysisConfiguration, private val definitionProviders: List = emptyList(), - private val emitters: List = emptyList() + emitters: List = emptyList() ) : ClassAndMemberVisitor(configuration, classVisitor) { + /** + * Internal [Emitter] to add static field initializers to + * any class constructor method. + */ + private inner class PrependClassInitializer : Emitter { + override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { + if (instruction is MethodEntry + && instruction.method.memberName == "" && instruction.method.signature == "()V" + && initializers.isNotEmpty()) { + writeByteCode(initializers) + initializers.clear() + } + } + } + + private val emitters: List = emitters + PrependClassInitializer() + private val initializers = mutableListOf() + /** * Tracks whether any modifications have been applied to any of the processed class(es) and pertinent members. */ @@ -44,6 +66,29 @@ class ClassMutator( return super.visitClass(resultingClass) } + /** + * If we have some static fields to initialise, and haven't already added them + * to an existing class initialiser block then we need to create one. + */ + override fun visitClassEnd(classVisitor: ClassVisitor, clazz: ClassRepresentation) { + tryWriteClassInitializer(classVisitor) + super.visitClassEnd(classVisitor, clazz) + } + + private fun tryWriteClassInitializer(classVisitor: ClassVisitor) { + if (initializers.isNotEmpty()) { + classVisitor.visitMethod(ACC_STATIC, "", "()V", null, null)?.also { mv -> + mv.visitCode() + EmitterModule(mv).writeByteCode(initializers) + mv.visitInsn(RETURN) + mv.visitMaxs(-1, -1) + mv.visitEnd() + } + initializers.clear() + hasBeenModified = true + } + } + /** * Apply definition providers to a method. This can be used to update the name or definition (pertinent meta-data) * of a class member. @@ -71,6 +116,7 @@ class ClassMutator( } if (field != resultingField) { logger.trace("Field has been mutated {}", field) + initializers += resultingField.body hasBeenModified = true } return super.visitField(clazz, resultingField) diff --git a/djvm/src/main/kotlin/net/corda/djvm/code/EmitterModule.kt b/djvm/src/main/kotlin/net/corda/djvm/code/EmitterModule.kt index afe9b5165d..2e2d2fc2e4 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/code/EmitterModule.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/code/EmitterModule.kt @@ -1,5 +1,6 @@ package net.corda.djvm.code +import net.corda.djvm.references.MethodBody import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* @@ -44,17 +45,9 @@ class EmitterModule( } /** - * Emit instruction for loading an integer constant onto the stack. + * Emit instruction for loading a constant onto the stack. */ - fun loadConstant(constant: Int) { - hasEmittedCustomCode = true - methodVisitor.visitLdcInsn(constant) - } - - /** - * Emit instruction for loading a string constant onto the stack. - */ - fun loadConstant(constant: String) { + fun loadConstant(constant: Any) { hasEmittedCustomCode = true methodVisitor.visitLdcInsn(constant) } @@ -67,6 +60,14 @@ class EmitterModule( methodVisitor.visitMethodInsn(INVOKESTATIC, owner, name, descriptor, isInterface) } + /** + * Emit instruction for invoking a virtual method. + */ + fun invokeVirtual(owner: String, name: String, descriptor: String, isInterface: Boolean = false) { + hasEmittedCustomCode = true + methodVisitor.visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, isInterface) + } + /** * Emit instruction for invoking a special method, e.g. a constructor or a method on a super-type. */ @@ -82,6 +83,19 @@ class EmitterModule( invokeSpecial(Type.getInternalName(T::class.java), name, descriptor, isInterface) } + fun invokeInterface(owner: String, name: String, descriptor: String) { + methodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true) + hasEmittedCustomCode = true + } + + /** + * Emit instruction for storing a value into a static field. + */ + fun putStatic(owner: String, name: String, descriptor: String) { + methodVisitor.visitFieldInsn(PUTSTATIC, owner, name, descriptor) + hasEmittedCustomCode = true + } + /** * Emit instruction for popping one element off the stack. */ @@ -98,11 +112,52 @@ class EmitterModule( methodVisitor.visitInsn(DUP) } + /** + * Emit instruction for pushing an object reference + * from a register onto the stack. + */ + fun pushObject(regNum: Int) { + methodVisitor.visitVarInsn(ALOAD, regNum) + hasEmittedCustomCode = true + } + + /** + * Emit instruction for pushing an integer value + * from a register onto the stack. + */ + fun pushInteger(regNum: Int) { + methodVisitor.visitVarInsn(ILOAD, regNum) + hasEmittedCustomCode = true + } + + /** + * Emit instructions to rearrange the stack as follows: + * [W1] [W3] + * [W2] -> [W1] + * [w3] [W2] + */ + fun raiseThirdWordToTop() { + methodVisitor.visitInsn(DUP2_X1) + methodVisitor.visitInsn(POP2) + hasEmittedCustomCode = true + } + + /** + * Emit instructions to rearrange the stack as follows: + * [W1] [W2] + * [W2] -> [W3] + * [W3] [W1] + */ + fun sinkTopToThirdWord() { + methodVisitor.visitInsn(DUP_X2) + methodVisitor.visitInsn(POP) + hasEmittedCustomCode = true + } + /** * Emit a sequence of instructions for instantiating and throwing an exception based on the provided message. */ fun throwException(exceptionType: Class, message: String) { - hasEmittedCustomCode = true val exceptionName = Type.getInternalName(exceptionType) new(exceptionName) methodVisitor.visitInsn(DUP) @@ -121,6 +176,14 @@ class EmitterModule( hasEmittedCustomCode = true } + /** + * Emit instruction for a function that returns an object reference. + */ + fun returnObject() { + methodVisitor.visitInsn(ARETURN) + hasEmittedCustomCode = true + } + /** * Emit instructions for a new line number. */ @@ -131,6 +194,15 @@ class EmitterModule( hasEmittedCustomCode = true } + /** + * Write the bytecode from these [MethodBody] objects as provided. + */ + fun writeByteCode(bodies: Iterable) { + for (body in bodies) { + body(this) + } + } + /** * Tell the code writer not to emit the default instruction. */ diff --git a/djvm/src/main/kotlin/net/corda/djvm/code/Types.kt b/djvm/src/main/kotlin/net/corda/djvm/code/Types.kt index e137f196d5..93a9c5bf7d 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/code/Types.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/code/Types.kt @@ -12,4 +12,6 @@ val thresholdViolationError: String = Type.getInternalName(ThresholdViolationErr * Local extension method for normalizing a class name. */ val String.asPackagePath: String get() = this.replace('/', '.') -val String.asResourcePath: String get() = this.replace('.', '/') \ No newline at end of file +val String.asResourcePath: String get() = this.replace('.', '/') + +val String.emptyAsNull: String? get() = if (isEmpty()) null else this \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/code/instructions/ConstantInstruction.kt b/djvm/src/main/kotlin/net/corda/djvm/code/instructions/ConstantInstruction.kt new file mode 100644 index 0000000000..ebd90d6f02 --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/code/instructions/ConstantInstruction.kt @@ -0,0 +1,6 @@ +package net.corda.djvm.code.instructions + +import net.corda.djvm.code.Instruction +import org.objectweb.asm.Opcodes + +class ConstantInstruction(val value: Any) : Instruction(Opcodes.LDC) \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/code/instructions/MethodEntry.kt b/djvm/src/main/kotlin/net/corda/djvm/code/instructions/MethodEntry.kt new file mode 100644 index 0000000000..cde092c056 --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/code/instructions/MethodEntry.kt @@ -0,0 +1,9 @@ +package net.corda.djvm.code.instructions + +import net.corda.djvm.references.Member + +/** + * Pseudo-instruction marking the beginning of a method. + * @property method [Member] describing this method. + */ +class MethodEntry(val method: Member): NoOperationInstruction() \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt b/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt index b69585538f..245e00902d 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt @@ -5,6 +5,7 @@ import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.messages.Message import net.corda.djvm.references.ClassReference import net.corda.djvm.references.MemberReference +import net.corda.djvm.references.ReferenceWithLocation import net.corda.djvm.rewiring.LoadedClass import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rewiring.SandboxClassLoadingException @@ -67,12 +68,24 @@ open class SandboxExecutor( val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration) val result = IsolatedTask(runnableClass.qualifiedClassName, configuration).run { validate(context, classLoader, classSources) - val loadedClass = classLoader.loadClassAndBytes(runnableClass, context) - val instance = loadedClass.type.newInstance() - val method = loadedClass.type.getMethod("apply", Any::class.java) + + // Load the "entry-point" task class into the sandbox. This task will marshall + // the input and outputs between Java types and sandbox wrapper types. + val taskClass = Class.forName("sandbox.Task", false, classLoader) + + // Create the user's task object inside the sandbox. + val runnable = classLoader.loadForSandbox(runnableClass, context).type.newInstance() + + // Fetch this sandbox's instance of Class so we can retrieve Task(Function) + // and then instantiate the Task. + val functionClass = Class.forName("sandbox.java.util.function.Function", false, classLoader) + val task = taskClass.getDeclaredConstructor(functionClass).newInstance(runnable) + + // Execute the task... + val method = taskClass.getMethod("apply", Any::class.java) try { @Suppress("UNCHECKED_CAST") - method.invoke(instance, input) as? TOutput + method.invoke(task, input) as? TOutput } catch (ex: InvocationTargetException) { throw ex.targetException } @@ -101,7 +114,7 @@ open class SandboxExecutor( fun load(classSource: ClassSource): LoadedClass { val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration) val result = IsolatedTask("LoadClass", configuration).run { - classLoader.loadClassAndBytes(classSource, context) + classLoader.loadForSandbox(classSource, context) } return result.output ?: throw ClassNotFoundException(classSource.qualifiedClassName) } @@ -146,7 +159,7 @@ open class SandboxExecutor( ): ReferenceValidationSummary { processClassQueue(*classSources.toTypedArray()) { classSource, className -> val didLoad = try { - classLoader.loadClassAndBytes(classSource, context) + classLoader.loadForSandbox(classSource, context) true } catch (exception: SandboxClassLoadingException) { // Continue; all warnings and errors are captured in [context.messages] @@ -155,7 +168,7 @@ open class SandboxExecutor( if (didLoad) { context.classes[className]?.apply { context.references.referencesFromLocation(className) - .map { it.reference } + .map(ReferenceWithLocation::reference) .filterIsInstance() .filter { it.className != className } .distinct() @@ -201,6 +214,7 @@ open class SandboxExecutor( } } - private val logger = loggerFor>() - + private companion object { + private val logger = loggerFor>() + } } diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/ClassRewriter.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/ClassRewriter.kt index 473718512a..081bff4fa5 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rewiring/ClassRewriter.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/ClassRewriter.kt @@ -1,23 +1,26 @@ package net.corda.djvm.rewiring import net.corda.djvm.SandboxConfiguration +import net.corda.djvm.analysis.AnalysisConfiguration import net.corda.djvm.analysis.AnalysisContext +import net.corda.djvm.analysis.ClassAndMemberVisitor.Companion.API_VERSION import net.corda.djvm.code.ClassMutator +import net.corda.djvm.code.EmitterModule +import net.corda.djvm.code.emptyAsNull +import net.corda.djvm.references.Member import net.corda.djvm.utilities.loggerFor import org.objectweb.asm.ClassReader -import org.objectweb.asm.commons.ClassRemapper +import org.objectweb.asm.ClassVisitor /** - * Functionality for rewrite parts of a class as it is being loaded. + * Functionality for rewriting parts of a class as it is being loaded. * * @property configuration The configuration of the sandbox. * @property classLoader The class loader used to load the classes that are to be rewritten. - * @property remapper A sandbox-aware remapper for inspecting and correcting type names and descriptors. */ open class ClassRewriter( private val configuration: SandboxConfiguration, - private val classLoader: ClassLoader, - private val remapper: SandboxRemapper = SandboxRemapper(configuration.analysisConfiguration.classResolver) + private val classLoader: ClassLoader ) { /** @@ -29,20 +32,53 @@ open class ClassRewriter( fun rewrite(reader: ClassReader, context: AnalysisContext): ByteCode { logger.debug("Rewriting class {}...", reader.className) val writer = SandboxClassWriter(reader, classLoader) - val classRemapper = ClassRemapper(writer, remapper) + val analysisConfiguration = configuration.analysisConfiguration + val classRemapper = SandboxClassRemapper(InterfaceStitcher(writer, analysisConfiguration), analysisConfiguration) val visitor = ClassMutator( classRemapper, - configuration.analysisConfiguration, + analysisConfiguration, configuration.definitionProviders, configuration.emitters ) visitor.analyze(reader, context, options = ClassReader.EXPAND_FRAMES) - val hasBeenModified = visitor.hasBeenModified - return ByteCode(writer.toByteArray(), hasBeenModified) + return ByteCode(writer.toByteArray(), visitor.hasBeenModified) } private companion object { private val logger = loggerFor() } + /** + * Extra visitor that is applied after [SandboxRemapper]. This "stitches" the original + * unmapped interface as a super-interface of the mapped version. + */ + private class InterfaceStitcher(parent: ClassVisitor, private val configuration: AnalysisConfiguration) + : ClassVisitor(API_VERSION, parent) + { + private val extraMethods = mutableListOf() + + override fun visit(version: Int, access: Int, className: String, signature: String?, superName: String?, interfaces: Array?) { + val stitchedInterfaces = configuration.stitchedInterfaces[className]?.let { methods -> + extraMethods += methods + arrayOf(*(interfaces ?: emptyArray()), configuration.classResolver.reverse(className)) + } ?: interfaces + + super.visit(version, access, className, signature, superName, stitchedInterfaces) + } + + override fun visitEnd() { + for (method in extraMethods) { + method.apply { + visitMethod(access, memberName, signature, genericsDetails.emptyAsNull, exceptions.toTypedArray())?.also { mv -> + mv.visitCode() + EmitterModule(mv).writeByteCode(body) + mv.visitMaxs(-1, -1) + mv.visitEnd() + } + } + } + extraMethods.clear() + super.visitEnd() + } + } } diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt index 5740534526..7f2abccb6a 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt @@ -18,7 +18,7 @@ import net.corda.djvm.validation.RuleValidator class SandboxClassLoader( configuration: SandboxConfiguration, private val context: AnalysisContext -) : ClassLoader(null) { +) : ClassLoader() { private val analysisConfiguration = configuration.analysisConfiguration @@ -36,11 +36,6 @@ class SandboxClassLoader( val analyzer: ClassAndMemberVisitor get() = ruleValidator - /** - * Set of classes that should be left untouched due to pinning. - */ - private val pinnedClasses = analysisConfiguration.pinnedClasses - /** * Set of classes that should be left untouched due to whitelisting. */ @@ -61,6 +56,17 @@ class SandboxClassLoader( */ private val rewriter: ClassRewriter = ClassRewriter(configuration, supportingClassLoader) + /** + * Given a class name, provide its corresponding [LoadedClass] for the sandbox. + */ + fun loadForSandbox(name: String, context: AnalysisContext): LoadedClass { + return loadClassAndBytes(ClassSource.fromClassName(analysisConfiguration.classResolver.resolveNormalized(name)), context) + } + + fun loadForSandbox(source: ClassSource, context: AnalysisContext): LoadedClass { + return loadForSandbox(source.qualifiedClassName, context) + } + /** * Load the class with the specified binary name. * @@ -69,69 +75,68 @@ class SandboxClassLoader( * * @return The resulting Class object. */ + @Throws(ClassNotFoundException::class) override fun loadClass(name: String, resolve: Boolean): Class<*> { - return loadClassAndBytes(ClassSource.fromClassName(name), context).type + val source = ClassSource.fromClassName(name) + return if (name.startsWith("sandbox.") && !analysisConfiguration.isPinnedClass(source.internalClassName)) { + loadClassAndBytes(source, context).type + } else { + super.loadClass(name, resolve) + } } /** * Load the class with the specified binary name. * - * @param source The class source, including the binary name of the class. + * @param request The class request, including the binary name of the class. * @param context The context in which the analysis is conducted. * * @return The resulting Class object and its byte code representation. */ - fun loadClassAndBytes(source: ClassSource, context: AnalysisContext): LoadedClass { - logger.debug("Loading class {}, origin={}...", source.qualifiedClassName, source.origin) - val name = analysisConfiguration.classResolver.reverseNormalized(source.qualifiedClassName) - val resolvedName = analysisConfiguration.classResolver.resolveNormalized(name) + private fun loadClassAndBytes(request: ClassSource, context: AnalysisContext): LoadedClass { + logger.debug("Loading class {}, origin={}...", request.qualifiedClassName, request.origin) + val requestedPath = request.internalClassName + val sourceName = analysisConfiguration.classResolver.reverseNormalized(request.qualifiedClassName) + val resolvedName = analysisConfiguration.classResolver.resolveNormalized(sourceName) // Check if the class has already been loaded. - val loadedClass = loadedClasses[name] + val loadedClass = loadedClasses[requestedPath] if (loadedClass != null) { - logger.trace("Class {} already loaded", source.qualifiedClassName) + logger.trace("Class {} already loaded", request.qualifiedClassName) return loadedClass + } else if (analysisConfiguration.isPinnedClass(requestedPath)) { + logger.debug("Class {} is loaded unmodified", request.qualifiedClassName) + return loadUnmodifiedClass(requestedPath) } - // Load the byte code for the specified class. - val reader = supportingClassLoader.classReader(name, context, source.origin) + val byteCode = if (analysisConfiguration.isTemplateClass(requestedPath)) { + loadUnmodifiedByteCode(requestedPath) + } else { + // Load the byte code for the specified class. + val reader = supportingClassLoader.classReader(sourceName, context, request.origin) - // Analyse the class if not matching the whitelist. - val readClassName = reader.className - if (!analysisConfiguration.whitelist.matches(readClassName)) { - logger.trace("Class {} does not match with the whitelist", source.qualifiedClassName) - logger.trace("Analyzing class {}...", source.qualifiedClassName) - analyzer.analyze(reader, context) - } - - // Check if the class should be left untouched. - val qualifiedName = name.asResourcePath - if (qualifiedName in pinnedClasses) { - logger.trace("Class {} is marked as pinned", source.qualifiedClassName) - val pinnedClasses = LoadedClass( - supportingClassLoader.loadClass(name), - ByteCode(ByteArray(0), false) - ) - loadedClasses[name] = pinnedClasses - if (source.origin != null) { - context.recordClassOrigin(name, ClassReference(source.origin)) + // Analyse the class if not matching the whitelist. + val readClassName = reader.className + if (!analysisConfiguration.whitelist.matches(readClassName)) { + logger.trace("Class {} does not match with the whitelist", request.qualifiedClassName) + logger.trace("Analyzing class {}...", request.qualifiedClassName) + analyzer.analyze(reader, context) } - return pinnedClasses - } - // Check if any errors were found during analysis. - if (context.messages.errorCount > 0) { - logger.trace("Errors detected after analyzing class {}", source.qualifiedClassName) - throw SandboxClassLoadingException(context) - } + // Check if any errors were found during analysis. + if (context.messages.errorCount > 0) { + logger.debug("Errors detected after analyzing class {}", request.qualifiedClassName) + throw SandboxClassLoadingException(context) + } - // Transform the class definition and byte code in accordance with provided rules. - val byteCode = rewriter.rewrite(reader, context) + // Transform the class definition and byte code in accordance with provided rules. + rewriter.rewrite(reader, context) + } // Try to define the transformed class. val clazz = try { when { - whitelistedClasses.matches(qualifiedName) -> supportingClassLoader.loadClass(name) + whitelistedClasses.matches(sourceName.asResourcePath) -> supportingClassLoader.loadClass(sourceName) else -> defineClass(resolvedName, byteCode.bytes, 0, byteCode.bytes.size) } } catch (exception: SecurityException) { @@ -140,19 +145,31 @@ class SandboxClassLoader( // Cache transformed class. val classWithByteCode = LoadedClass(clazz, byteCode) - loadedClasses[name] = classWithByteCode - if (source.origin != null) { - context.recordClassOrigin(name, ClassReference(source.origin)) + loadedClasses[requestedPath] = classWithByteCode + if (request.origin != null) { + context.recordClassOrigin(sourceName, ClassReference(request.origin)) } logger.debug("Loaded class {}, bytes={}, isModified={}", - source.qualifiedClassName, byteCode.bytes.size, byteCode.isModified) + request.qualifiedClassName, byteCode.bytes.size, byteCode.isModified) return classWithByteCode } + private fun loadUnmodifiedByteCode(internalClassName: String): ByteCode { + return ByteCode((getSystemClassLoader().getResourceAsStream("$internalClassName.class") + ?: throw ClassNotFoundException(internalClassName)).readBytes(), false) + } + + private fun loadUnmodifiedClass(className: String): LoadedClass { + return LoadedClass(supportingClassLoader.loadClass(className), UNMODIFIED).apply { + loadedClasses[className] = this + } + } + private companion object { private val logger = loggerFor() + private val UNMODIFIED = ByteCode(ByteArray(0), false) } } diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassRemapper.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassRemapper.kt new file mode 100644 index 0000000000..7412999727 --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassRemapper.kt @@ -0,0 +1,52 @@ +package net.corda.djvm.rewiring + +import net.corda.djvm.analysis.AnalysisConfiguration +import net.corda.djvm.analysis.ClassAndMemberVisitor.Companion.API_VERSION +import org.objectweb.asm.ClassVisitor +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.commons.ClassRemapper + +class SandboxClassRemapper(cv: ClassVisitor, private val configuration: AnalysisConfiguration) + : ClassRemapper(cv, SandboxRemapper(configuration.classResolver, configuration.whitelist) +) { + override fun createMethodRemapper(mv: MethodVisitor): MethodVisitor { + return MethodRemapperWithPinning(mv, super.createMethodRemapper(mv)) + } + + /** + * Do not attempt to remap references to methods and fields on pinned classes. + * For example, the methods on [RuntimeCostAccounter] really DO use [java.lang.String] + * rather than [sandbox.java.lang.String]. + */ + private inner class MethodRemapperWithPinning(private val nonmapper: MethodVisitor, remapper: MethodVisitor) + : MethodVisitor(API_VERSION, remapper) { + + private fun mapperFor(element: Element): MethodVisitor { + return if (configuration.isPinnedClass(element.owner) || configuration.isTemplateClass(element.owner) || isUnmapped(element)) { + nonmapper + } else { + mv + } + } + + override fun visitMethodInsn(opcode: Int, owner: String, name: String, descriptor: String, isInterface: Boolean) { + val method = Element(owner, name, descriptor) + return mapperFor(method).visitMethodInsn(opcode, owner, name, descriptor, isInterface) + } + + override fun visitTryCatchBlock(start: Label, end: Label, handler: Label, type: String?) { + // Don't map caught exception names - these could be thrown by the JVM itself. + nonmapper.visitTryCatchBlock(start, end, handler, type) + } + + override fun visitFieldInsn(opcode: Int, owner: String, name: String, descriptor: String) { + val field = Element(owner, name, descriptor) + return mapperFor(field).visitFieldInsn(opcode, owner, name, descriptor) + } + } + + private fun isUnmapped(element: Element): Boolean = configuration.whitelist.matches(element.owner) + + private data class Element(val owner: String, val name: String, val descriptor: String) +} \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassWriter.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassWriter.kt index e1a051d45c..fc0ad559f6 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassWriter.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassWriter.kt @@ -8,13 +8,13 @@ import org.objectweb.asm.ClassWriter.COMPUTE_MAXS import org.objectweb.asm.Type /** - * Class writer for sandbox execution, with configurable a [classLoader] to ensure correct deduction of the used class + * Class writer for sandbox execution, with a configurable classloader to ensure correct deduction of the used class * hierarchy. * * @param classReader The [ClassReader] used to read the original class. It will be used to copy the entire constant * pool and bootstrap methods from the original class and also to copy other fragments of original byte code where * applicable. - * @property classLoader The class loader used to load the classes that are to be rewritten. + * @property cloader The class loader used to load the classes that are to be rewritten. * @param flags Option flags that can be used to modify the default behaviour of this class. Must be zero or a * combination of [COMPUTE_MAXS] and [COMPUTE_FRAMES]. These option flags do not affect methods that are copied as is * in the new class. This means that neither the maximum stack size nor the stack frames will be computed for these @@ -61,7 +61,7 @@ open class SandboxClassWriter( } } - companion object { + private companion object { private const val OBJECT_NAME = "java/lang/Object" diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxRemapper.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxRemapper.kt index 566b377fe6..e828fa4480 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxRemapper.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxRemapper.kt @@ -1,15 +1,19 @@ package net.corda.djvm.rewiring import net.corda.djvm.analysis.ClassResolver +import net.corda.djvm.analysis.Whitelist +import org.objectweb.asm.* import org.objectweb.asm.commons.Remapper /** * Class name and descriptor re-mapper for use in a sandbox. * * @property classResolver Functionality for resolving the class name of a sandboxed or sandboxable class. + * @property whitelist Identifies the Java APIs which are not mapped into the sandbox namespace. */ open class SandboxRemapper( - private val classResolver: ClassResolver + private val classResolver: ClassResolver, + private val whitelist: Whitelist ) : Remapper() { /** @@ -26,6 +30,32 @@ open class SandboxRemapper( return rewriteTypeName(super.map(typename)) } + /** + * Mapper for [Type] and [Handle] objects. + */ + override fun mapValue(obj: Any?): Any? { + return if (obj is Handle && whitelist.matches(obj.owner)) { + obj + } else { + super.mapValue(obj) + } + } + + /** + * All [Object.toString] methods must be transformed to [sandbox.java.lang.Object.toDJVMString], + * to allow the return type to change to [sandbox.java.lang.String]. + * + * The [sandbox.java.lang.Object] class is pinned and not mapped. + */ + override fun mapMethodName(owner: String, name: String, descriptor: String): String { + val newName = if (name == "toString" && descriptor == "()Ljava/lang/String;") { + "toDJVMString" + } else { + name + } + return super.mapMethodName(owner, newName, descriptor) + } + /** * Function for rewriting a descriptor. */ diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ArgumentUnwrapper.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ArgumentUnwrapper.kt new file mode 100644 index 0000000000..952efc4251 --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ArgumentUnwrapper.kt @@ -0,0 +1,35 @@ +package net.corda.djvm.rules.implementation + +import net.corda.djvm.code.Emitter +import net.corda.djvm.code.EmitterContext +import net.corda.djvm.code.Instruction +import net.corda.djvm.code.instructions.MemberAccessInstruction + +/** + * Some whitelisted functions have [java.lang.String] arguments, so we + * need to unwrap the [sandbox.java.lang.String] object before invoking. + * + * There are lots of rabbits in this hole because method arguments are + * theoretically arbitrary. However, in practice WE control the whitelist. + */ +class ArgumentUnwrapper : Emitter { + override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { + if (instruction is MemberAccessInstruction && context.whitelist.matches(instruction.owner)) { + fun unwrapString() = invokeStatic("sandbox/java/lang/String", "fromDJVM", "(Lsandbox/java/lang/String;)Ljava/lang/String;") + + if (hasStringArgument(instruction)) { + unwrapString() + } else if (instruction.owner == "java/lang/Class" && instruction.signature.startsWith("(Ljava/lang/String;ZLjava/lang/ClassLoader;)")) { + /** + * [kotlin.jvm.internal.Intrinsics.checkHasClass] invokes [Class.forName], so I'm + * adding support for both of this function's variants. For now. + */ + raiseThirdWordToTop() + unwrapString() + sinkTopToThirdWord() + } + } + } + + private fun hasStringArgument(method: MemberAccessInstruction) = method.signature.contains("Ljava/lang/String;)") +} \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/DisallowNonDeterministicMethods.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/DisallowNonDeterministicMethods.kt index 04ef9e3d5c..2e3b43f935 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/DisallowNonDeterministicMethods.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/DisallowNonDeterministicMethods.kt @@ -17,7 +17,7 @@ class DisallowNonDeterministicMethods : Emitter { override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { if (instruction is MemberAccessInstruction && isForbidden(instruction)) { when (instruction.operation) { - INVOKEVIRTUAL -> { + INVOKEVIRTUAL, INVOKESPECIAL -> { throwException("Disallowed reference to API; ${memberFormatter.format(instruction.member)}") preventDefault() } @@ -31,12 +31,20 @@ class DisallowNonDeterministicMethods : Emitter { || instruction.signature.contains("Ljava/lang/reflect/")) ) + private fun isClassLoading(instruction: MemberAccessInstruction): Boolean = + (instruction.owner == "java/lang/ClassLoader") && instruction.memberName in CLASSLOADING_METHODS + private fun isObjectMonitor(instruction: MemberAccessInstruction): Boolean = - (instruction.signature == "()V" && (instruction.memberName == "notify" || instruction.memberName == "notifyAll" || instruction.memberName == "wait")) + (instruction.signature == "()V" && instruction.memberName in MONITOR_METHODS) || (instruction.memberName == "wait" && (instruction.signature == "(J)V" || instruction.signature == "(JI)V")) private fun isForbidden(instruction: MemberAccessInstruction): Boolean - = instruction.isMethod && (isClassReflection(instruction) || isObjectMonitor(instruction)) + = instruction.isMethod && (isClassReflection(instruction) || isObjectMonitor(instruction) || isClassLoading(instruction)) private val memberFormatter = MemberFormatter() + + private companion object { + private val MONITOR_METHODS = setOf("notify", "notifyAll", "wait") + private val CLASSLOADING_METHODS = setOf("defineClass", "loadClass", "findClass") + } } \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ReturnTypeWrapper.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ReturnTypeWrapper.kt new file mode 100644 index 0000000000..7f103f346b --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/ReturnTypeWrapper.kt @@ -0,0 +1,27 @@ +package net.corda.djvm.rules.implementation + +import net.corda.djvm.code.Emitter +import net.corda.djvm.code.EmitterContext +import net.corda.djvm.code.Instruction +import net.corda.djvm.code.instructions.MemberAccessInstruction + +/** + * Whitelisted classes may still return [java.lang.String] from some + * functions, e.g. [java.lang.Object.toString]. So always explicitly + * invoke [sandbox.java.lang.String.toDJVM] after these. + */ +class ReturnTypeWrapper : Emitter { + override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { + if (instruction is MemberAccessInstruction && context.whitelist.matches(instruction.owner)) { + fun invokeMethod() = invokeVirtual(instruction.owner, instruction.memberName, instruction.signature) + + if (hasStringReturnType(instruction)) { + preventDefault() + invokeMethod() + invokeStatic("sandbox/java/lang/String", "toDJVM", "(Ljava/lang/String;)Lsandbox/java/lang/String;") + } + } + } + + private fun hasStringReturnType(method: MemberAccessInstruction) = method.signature.endsWith(")Ljava/lang/String;") +} \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/RewriteClassMethods.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/RewriteClassMethods.kt new file mode 100644 index 0000000000..555ada7ed1 --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/RewriteClassMethods.kt @@ -0,0 +1,56 @@ +package net.corda.djvm.rules.implementation + +import net.corda.djvm.code.Emitter +import net.corda.djvm.code.EmitterContext +import net.corda.djvm.code.Instruction +import net.corda.djvm.code.instructions.MemberAccessInstruction +import org.objectweb.asm.Opcodes.* + +/** + * The enum-related methods on [Class] all require that enums use [java.lang.Enum] + * as their super class. So replace their all invocations with ones to equivalent + * methods on the DJVM class that require [sandbox.java.lang.Enum] instead. + */ +class RewriteClassMethods : Emitter { + override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { + if (instruction is MemberAccessInstruction && instruction.owner == "java/lang/Class") { + when (instruction.operation) { + INVOKEVIRTUAL -> if (instruction.memberName == "enumConstantDirectory" && instruction.signature == "()Ljava/util/Map;") { + invokeStatic( + owner = "sandbox/java/lang/DJVM", + name = "enumConstantDirectory", + descriptor = "(Ljava/lang/Class;)Lsandbox/java/util/Map;" + ) + preventDefault() + } else if (instruction.memberName == "isEnum" && instruction.signature == "()Z") { + invokeStatic( + owner = "sandbox/java/lang/DJVM", + name = "isEnum", + descriptor = "(Ljava/lang/Class;)Z" + ) + preventDefault() + } else if (instruction.memberName == "getEnumConstants" && instruction.signature == "()[Ljava/lang/Object;") { + invokeStatic( + owner = "sandbox/java/lang/DJVM", + name = "getEnumConstants", + descriptor = "(Ljava/lang/Class;)[Ljava/lang/Object;") + preventDefault() + } + + INVOKESTATIC -> if (isClassForName(instruction)) { + invokeStatic( + owner = "sandbox/java/lang/DJVM", + name = "classForName", + descriptor = instruction.signature + ) + preventDefault() + } + } + } + } + + private fun isClassForName(instruction: MemberAccessInstruction): Boolean + = instruction.memberName == "forName" && + (instruction.signature == "(Ljava/lang/String;)Ljava/lang/Class;" || + instruction.signature == "(Ljava/lang/String;ZLjava/lang/ClassLoader;)Ljava/lang/Class;") +} diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StaticConstantRemover.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StaticConstantRemover.kt new file mode 100644 index 0000000000..ea6826903e --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StaticConstantRemover.kt @@ -0,0 +1,30 @@ +package net.corda.djvm.rules.implementation + +import net.corda.djvm.analysis.AnalysisRuntimeContext +import net.corda.djvm.code.EmitterModule +import net.corda.djvm.code.MemberDefinitionProvider +import net.corda.djvm.references.Member + +/** + * Removes static constant objects that are initialised directly in the byte-code. + * Currently, the only use-case is for re-initialising [String] fields. + */ +class StaticConstantRemover : MemberDefinitionProvider { + + override fun define(context: AnalysisRuntimeContext, member: Member): Member = when { + isConstantField(member) -> member.copy(body = listOf(StringFieldInitializer(member)::writeInitializer), value = null) + else -> member + } + + private fun isConstantField(member: Member): Boolean = member.value != null && member.signature == "Ljava/lang/String;" + + class StringFieldInitializer(private val member: Member) { + fun writeInitializer(emitter: EmitterModule): Unit = with(emitter) { + member.value?.apply { + loadConstant(this) + invokeStatic("sandbox/java/lang/String", "toDJVM", "(Ljava/lang/String;)Lsandbox/java/lang/String;", false) + putStatic(member.className, member.memberName, "Lsandbox/java/lang/String;") + } + } + } +} \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StringConstantWrapper.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StringConstantWrapper.kt new file mode 100644 index 0000000000..6223cfd0b4 --- /dev/null +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StringConstantWrapper.kt @@ -0,0 +1,22 @@ +package net.corda.djvm.rules.implementation + +import net.corda.djvm.code.Emitter +import net.corda.djvm.code.EmitterContext +import net.corda.djvm.code.Instruction +import net.corda.djvm.code.instructions.ConstantInstruction + +/** + * Ensure that [String] constants loaded from the Constants + * Pool are wrapped into [sandbox.java.lang.String]. + */ +class StringConstantWrapper : Emitter { + override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { + if (instruction is ConstantInstruction) { + when (instruction.value) { + is String -> { + invokeStatic("sandbox/java/lang/String", "toDJVM", "(Ljava/lang/String;)Lsandbox/java/lang/String;", false) + } + } + } + } +} \ No newline at end of file diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutNativeMethods.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutNativeMethods.kt index 74a58f6c7f..d1b6918fef 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutNativeMethods.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutNativeMethods.kt @@ -23,7 +23,7 @@ class StubOutNativeMethods : MemberDefinitionProvider { private fun writeExceptionMethodBody(emitter: EmitterModule): Unit = with(emitter) { lineNumber(0) - throwException(RuleViolationError::class.java, "Native method has been deleted") + throwException("Native method has been deleted") } private fun writeStubMethodBody(emitter: EmitterModule): Unit = with(emitter) { diff --git a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutReflectionMethods.kt b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutReflectionMethods.kt index 4e486bf289..9c60f420bd 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutReflectionMethods.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rules/implementation/StubOutReflectionMethods.kt @@ -19,7 +19,7 @@ class StubOutReflectionMethods : MemberDefinitionProvider { private fun writeMethodBody(emitter: EmitterModule): Unit = with(emitter) { lineNumber(0) - throwException(RuleViolationError::class.java, "Disallowed reference to reflection API") + throwException("Disallowed reference to reflection API") } // The method must be public and with a Java implementation. diff --git a/djvm/src/main/kotlin/net/corda/djvm/source/ClassSource.kt b/djvm/src/main/kotlin/net/corda/djvm/source/ClassSource.kt index cebe0b0b82..99ef5319fb 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/source/ClassSource.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/source/ClassSource.kt @@ -1,17 +1,20 @@ package net.corda.djvm.source +import net.corda.djvm.code.asResourcePath import java.nio.file.Path /** * The source of one or more compiled Java classes. * * @property qualifiedClassName The fully qualified class name. + * @property internalClassName The fully qualified internal class name, i.e. with '/' instead of '.'. * @property origin The origin of the class source, if any. */ class ClassSource private constructor( val qualifiedClassName: String = "", val origin: String? = null ) { + val internalClassName: String = qualifiedClassName.asResourcePath companion object { diff --git a/djvm/src/main/kotlin/net/corda/djvm/utilities/Discovery.kt b/djvm/src/main/kotlin/net/corda/djvm/utilities/Discovery.kt index 9092e5c044..33886e76ae 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/utilities/Discovery.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/utilities/Discovery.kt @@ -7,7 +7,7 @@ import java.lang.reflect.Modifier * Find and instantiate types that implement a certain interface. */ object Discovery { - const val FORBIDDEN_CLASS_MASK = (Modifier.STATIC or Modifier.ABSTRACT) + const val FORBIDDEN_CLASS_MASK = (Modifier.STATIC or Modifier.ABSTRACT or Modifier.PRIVATE or Modifier.PROTECTED) /** * Get an instance of each concrete class that implements interface or class [T]. diff --git a/djvm/src/main/kotlin/net/corda/djvm/validation/RuleValidator.kt b/djvm/src/main/kotlin/net/corda/djvm/validation/RuleValidator.kt index 1f4ead8cd1..b409e83df5 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/validation/RuleValidator.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/validation/RuleValidator.kt @@ -17,6 +17,7 @@ import org.objectweb.asm.ClassVisitor * Helper class for validating a set of rules for a class or set of classes. * * @property rules A set of rules to validate for provided classes. + * @param configuration The configuration to use for class analysis. * @param classVisitor Class visitor to use when traversing the structure of classes. */ class RuleValidator( diff --git a/djvm/src/main/kotlin/sandbox/Task.kt b/djvm/src/main/kotlin/sandbox/Task.kt new file mode 100644 index 0000000000..8a2bbab78a --- /dev/null +++ b/djvm/src/main/kotlin/sandbox/Task.kt @@ -0,0 +1,24 @@ +@file:JvmName("TaskTypes") +package sandbox + +import sandbox.java.lang.sandbox +import sandbox.java.lang.unsandbox + +typealias SandboxFunction = sandbox.java.util.function.Function + +@Suppress("unused") +class Task(private val function: SandboxFunction?) : SandboxFunction { + + /** + * This function runs inside the sandbox. It marshalls the input + * object to its sandboxed equivalent, executes the user's code + * and then marshalls the result out again. + * + * The marshalling should be effective for Java primitives, + * Strings and Enums, as well as for arrays of these types. + */ + override fun apply(input: Any?): Any? { + return function?.apply(input?.sandbox())?.unsandbox() + } + +} diff --git a/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt b/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt new file mode 100644 index 0000000000..b6a3acdc77 --- /dev/null +++ b/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt @@ -0,0 +1,158 @@ +@file:JvmName("DJVM") +@file:Suppress("unused") +package sandbox.java.lang + +import org.objectweb.asm.Opcodes.ACC_ENUM + +private const val SANDBOX_PREFIX = "sandbox." + +fun Any.unsandbox(): Any { + return when (this) { + is Enum<*> -> fromDJVMEnum() + is Object -> fromDJVM() + is Array<*> -> fromDJVMArray() + else -> this + } +} + +fun Any.sandbox(): Any { + return when (this) { + is kotlin.String -> String.toDJVM(this) + is kotlin.Char -> Character.toDJVM(this) + is kotlin.Long -> Long.toDJVM(this) + is kotlin.Int -> Integer.toDJVM(this) + is kotlin.Short -> Short.toDJVM(this) + is kotlin.Byte -> Byte.toDJVM(this) + is kotlin.Float -> Float.toDJVM(this) + is kotlin.Double -> Double.toDJVM(this) + is kotlin.Boolean -> Boolean.toDJVM(this) + is kotlin.Enum<*> -> toDJVMEnum() + is Array<*> -> toDJVMArray() + else -> this + } +} + +private fun Array<*>.fromDJVMArray(): Array<*> = Object.fromDJVM(this) + +/** + * These functions use the "current" classloader, i.e. classloader + * that owns this DJVM class. + */ +private fun Class<*>.toDJVMType(): Class<*> = Class.forName(name.toSandboxPackage()) +private fun Class<*>.fromDJVMType(): Class<*> = Class.forName(name.fromSandboxPackage()) + +private fun kotlin.String.toSandboxPackage(): kotlin.String { + return if (startsWith(SANDBOX_PREFIX)) { + this + } else { + SANDBOX_PREFIX + this + } +} + +private fun kotlin.String.fromSandboxPackage(): kotlin.String { + return if (startsWith(SANDBOX_PREFIX)) { + drop(SANDBOX_PREFIX.length) + } else { + this + } +} + +private inline fun Array<*>.toDJVMArray(): Array { + @Suppress("unchecked_cast") + return (java.lang.reflect.Array.newInstance(javaClass.componentType.toDJVMType(), size) as Array).also { + for ((i, item) in withIndex()) { + it[i] = item?.sandbox() as T + } + } +} + +private fun Enum<*>.fromDJVMEnum(): kotlin.Enum<*> { + return javaClass.fromDJVMType().enumConstants[ordinal()] as kotlin.Enum<*> +} + +private fun kotlin.Enum<*>.toDJVMEnum(): Enum<*> { + @Suppress("unchecked_cast") + return (getEnumConstants(javaClass.toDJVMType() as Class>) as Array>)[ordinal] +} + +/** + * Replacement functions for the members of Class<*> that support Enums. + */ +fun isEnum(clazz: Class<*>): kotlin.Boolean + = (clazz.modifiers and ACC_ENUM != 0) && (clazz.superclass == sandbox.java.lang.Enum::class.java) + +fun getEnumConstants(clazz: Class>): Array<*>? { + return getEnumConstantsShared(clazz)?.clone() +} + +internal fun enumConstantDirectory(clazz: Class>): sandbox.java.util.Map>? { + // DO NOT replace get with Kotlin's [] because Kotlin would use java.util.Map. + return allEnumDirectories.get(clazz) ?: createEnumDirectory(clazz) +} + +@Suppress("unchecked_cast") +internal fun getEnumConstantsShared(clazz: Class>): Array>? { + return if (isEnum(clazz)) { + // DO NOT replace get with Kotlin's [] because Kotlin would use java.util.Map. + allEnums.get(clazz) ?: createEnum(clazz) + } else { + null + } +} + +@Suppress("unchecked_cast") +private fun createEnum(clazz: Class>): Array>? { + return clazz.getMethod("values").let { method -> + method.isAccessible = true + method.invoke(null) as? Array> + // DO NOT replace put with Kotlin's [] because Kotlin would use java.util.Map. + }?.apply { allEnums.put(clazz, this) } +} + +private fun createEnumDirectory(clazz: Class>): sandbox.java.util.Map> { + val universe = getEnumConstantsShared(clazz) ?: throw IllegalArgumentException("${clazz.name} is not an enum type") + val directory = sandbox.java.util.LinkedHashMap>(2 * universe.size) + for (entry in universe) { + // DO NOT replace put with Kotlin's [] because Kotlin would use java.util.Map. + directory.put(entry.name(), entry) + } + // DO NOT replace put with Kotlin's [] because Kotlin would use java.util.Map. + allEnumDirectories.put(clazz, directory) + return directory +} + +private val allEnums: sandbox.java.util.Map>, Array>> = sandbox.java.util.LinkedHashMap() +private val allEnumDirectories: sandbox.java.util.Map>, sandbox.java.util.Map>> = sandbox.java.util.LinkedHashMap() + +/** + * Replacement functions for Class<*>.forName(...) which protect + * against users loading classes from outside the sandbox. + */ +@Throws(ClassNotFoundException::class) +fun classForName(className: kotlin.String): Class<*> { + return Class.forName(toSandbox(className)) +} + +@Throws(ClassNotFoundException::class) +fun classForName(className: kotlin.String, initialize: kotlin.Boolean, classLoader: ClassLoader): Class<*> { + return Class.forName(toSandbox(className), initialize, classLoader) +} + +/** + * Force the qualified class name into the sandbox.* namespace. + * Throw [ClassNotFoundException] anyway if we wouldn't want to + * return the resulting sandbox class. E.g. for any of our own + * internal classes. + */ +private fun toSandbox(className: kotlin.String): kotlin.String { + if (bannedClasses.any { it.matches(className) }) { + throw ClassNotFoundException(className) + } + return SANDBOX_PREFIX + className +} + +private val bannedClasses = setOf( + "^java\\.lang\\.DJVM(.*)?\$".toRegex(), + "^net\\.corda\\.djvm\\..*\$".toRegex(), + "^Task\$".toRegex() +) diff --git a/djvm/src/main/kotlin/sandbox/java/lang/Object.kt b/djvm/src/main/kotlin/sandbox/java/lang/Object.kt deleted file mode 100644 index 14ae5df025..0000000000 --- a/djvm/src/main/kotlin/sandbox/java/lang/Object.kt +++ /dev/null @@ -1,19 +0,0 @@ -package sandbox.java.lang - -/** - * Sandboxed implementation of `java/lang/Object`. - */ -@Suppress("EqualsOrHashCode") -open class Object { - - /** - * Deterministic hash code for objects. - */ - override fun hashCode(): Int = sandbox.java.lang.System.identityHashCode(this) - - /** - * Deterministic string representation of [Object]. - */ - override fun toString(): String = "sandbox.java.lang.Object@${hashCode().toString(16)}" - -} diff --git a/djvm/src/main/kotlin/sandbox/java/lang/System.kt b/djvm/src/main/kotlin/sandbox/java/lang/System.kt deleted file mode 100644 index 0b40e0cfd7..0000000000 --- a/djvm/src/main/kotlin/sandbox/java/lang/System.kt +++ /dev/null @@ -1,99 +0,0 @@ -@file:Suppress("UNUSED_PARAMETER") - -package sandbox.java.lang - -import java.io.IOException -import java.util.* - -object System { - - private var objectCounter = object : ThreadLocal() { - override fun initialValue() = 0 - } - - private var objectHashCodes = object : ThreadLocal>() { - override fun initialValue() = mutableMapOf() - } - - @JvmField - val `in`: java.io.InputStream? = null - - @JvmField - val out: java.io.PrintStream? = null - - @JvmField - val err: java.io.PrintStream? = null - - fun setIn(stream: java.io.InputStream) {} - - fun setOut(stream: java.io.PrintStream) {} - - fun setErr(stream: java.io.PrintStream) {} - - fun console(): java.io.Console? { - throw NotImplementedError() - } - - @Throws(java.io.IOException::class) - fun inheritedChannel(): java.nio.channels.Channel? { - throw IOException() - } - - fun setSecurityManager(manager: java.lang.SecurityManager) {} - - fun getSecurityManager(): java.lang.SecurityManager? = null - - fun currentTimeMillis(): Long = 0L - - fun nanoTime(): Long = 0L - - fun arraycopy(src: Object, srcPos: Int, dest: Object, destPos: Int, length: Int) { - java.lang.System.arraycopy(src, srcPos, dest, destPos, length) - } - - fun identityHashCode(obj: Object): Int { - val nativeHashCode = java.lang.System.identityHashCode(obj) - // TODO Instead of using a magic offset below, one could take in a per-context seed - return objectHashCodes.get().getOrPut(nativeHashCode) { - val newCounter = objectCounter.get() + 1 - objectCounter.set(newCounter) - 0xfed_c0de + newCounter - } - } - - fun getProperties(): java.util.Properties { - return Properties() - } - - fun lineSeparator() = "\n" - - fun setProperties(properties: java.util.Properties) {} - - fun getProperty(property: String): String? = null - - fun getProperty(property: String, defaultValue: String): String? = defaultValue - - fun setProperty(property: String, value: String): String? = null - - fun clearProperty(property: String): String? = null - - fun getenv(variable: String): String? = null - - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - fun getenv(): java.util.Map? = null - - fun exit(exitCode: Int) {} - - fun gc() {} - - fun runFinalization() {} - - fun runFinalizersOnExit(flag: Boolean) {} - - fun load(path: String) {} - - fun loadLibrary(path: String) {} - - fun mapLibraryName(path: String): String? = null - -} diff --git a/djvm/src/test/kotlin/foo/bar/sandbox/KotlinClass.kt b/djvm/src/test/kotlin/foo/bar/sandbox/KotlinClass.kt index 44ea7e29f7..f54128a2d6 100644 --- a/djvm/src/test/kotlin/foo/bar/sandbox/KotlinClass.kt +++ b/djvm/src/test/kotlin/foo/bar/sandbox/KotlinClass.kt @@ -1,12 +1,9 @@ package foo.bar.sandbox -import java.util.* - -fun testRandom(): Int { - val random = Random() - return random.nextInt() +fun testClock(): Long { + return System.nanoTime() } -fun String.toNumber(): Int { - return this.toInt() +fun String.toNumber(): Long { + return this.toLong() } diff --git a/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt b/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt new file mode 100644 index 0000000000..d71fa1a36a --- /dev/null +++ b/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt @@ -0,0 +1,126 @@ +package net.corda.djvm + +import org.assertj.core.api.Assertions.* +import org.junit.Assert.* +import org.junit.Test +import sandbox.java.lang.sandbox +import sandbox.java.lang.unsandbox + +class DJVMTest { + + @Test + fun testDJVMString() { + val djvmString = sandbox.java.lang.String("New Value") + assertNotEquals(djvmString, "New Value") + assertEquals(djvmString, "New Value".sandbox()) + } + + @Test + fun testSimpleIntegerFormats() { + val result = sandbox.java.lang.String.format("%d-%d-%d-%d".toDJVM(), + 10.toDJVM(), 999999L.toDJVM(), 1234.toShort().toDJVM(), 108.toByte().toDJVM()).toString() + assertEquals("10-999999-1234-108", result) + } + + @Test + fun testHexFormat() { + val result = sandbox.java.lang.String.format("%0#6x".toDJVM(), 768.toDJVM()).toString() + assertEquals("0x0300", result) + } + + @Test + fun testDoubleFormat() { + val result = sandbox.java.lang.String.format("%9.4f".toDJVM(), 1234.5678.toDJVM()).toString() + assertEquals("1234.5678", result) + } + + @Test + fun testFloatFormat() { + val result = sandbox.java.lang.String.format("%7.2f".toDJVM(), 1234.5678f.toDJVM()).toString() + assertEquals("1234.57", result) + } + + @Test + fun testCharFormat() { + val result = sandbox.java.lang.String.format("[%c]".toDJVM(), 'A'.toDJVM()).toString() + assertEquals("[A]", result) + } + + @Test + fun testObjectFormat() { + val result = sandbox.java.lang.String.format("%s".toDJVM(), object : sandbox.java.lang.Object() {}).toString() + assertThat(result).startsWith("sandbox.java.lang.Object@") + } + + @Test + fun testStringEquality() { + val number = sandbox.java.lang.String.valueOf((Double.MIN_VALUE / 2.0) * 2.0) + require(number == "0.0".sandbox()) + } + + @Test + fun testSandboxingArrays() { + val result = arrayOf(1, 10L, "Hello World", '?', false, 1234.56).sandbox() + assertThat(result) + .isEqualTo(arrayOf(1.toDJVM(), 10L.toDJVM(), "Hello World".toDJVM(), '?'.toDJVM(), false.toDJVM(), 1234.56.toDJVM())) + } + + @Test + fun testUnsandboxingObjectArray() { + val result = arrayOf(1.toDJVM(), 10L.toDJVM(), "Hello World".toDJVM(), '?'.toDJVM(), false.toDJVM(), 1234.56.toDJVM()).unsandbox() + assertThat(result) + .isEqualTo(arrayOf(1, 10L, "Hello World", '?', false, 1234.56)) + } + + @Test + fun testSandboxingPrimitiveArray() { + val result = intArrayOf(1, 2, 3, 10).sandbox() + assertThat(result).isEqualTo(intArrayOf(1, 2, 3, 10)) + } + + @Test + fun testSandboxingIntegersAsObjectArray() { + val result = arrayOf(1, 2, 3, 10).sandbox() + assertThat(result).isEqualTo(arrayOf(1.toDJVM(), 2.toDJVM(), 3.toDJVM(), 10.toDJVM())) + } + + @Test + fun testUnsandboxingArrays() { + val arr = arrayOf( + Array(1) { "Hello".toDJVM() }, + Array(1) { 1234000L.toDJVM() }, + Array(1) { 1234.toDJVM() }, + Array(1) { 923.toShort().toDJVM() }, + Array(1) { 27.toByte().toDJVM() }, + Array(1) { 'X'.toDJVM() }, + Array(1) { 987.65f.toDJVM() }, + Array(1) { 343.282.toDJVM() }, + Array(1) { true.toDJVM() }, + ByteArray(1) { 127.toByte() }, + CharArray(1) { '?'} + ) + val result = arr.unsandbox() as Array<*> + assertEquals(arr.size, result.size) + assertArrayEquals(Array(1) { "Hello" }, result[0] as Array<*>) + assertArrayEquals(Array(1) { 1234000L }, result[1] as Array<*>) + assertArrayEquals(Array(1) { 1234 }, result[2] as Array<*>) + assertArrayEquals(Array(1) { 923.toShort() }, result[3] as Array<*>) + assertArrayEquals(Array(1) { 27.toByte() }, result[4] as Array<*>) + assertArrayEquals(Array(1) { 'X' }, result[5] as Array<*>) + assertArrayEquals(Array(1) { 987.65f }, result[6] as Array<*>) + assertArrayEquals(Array(1) { 343.282 }, result[7] as Array<*>) + assertArrayEquals(Array(1) { true }, result[8] as Array<*>) + assertArrayEquals(ByteArray(1) { 127.toByte() }, result[9] as ByteArray) + assertArrayEquals(CharArray(1) { '?' }, result[10] as CharArray) + } + + private fun String.toDJVM(): sandbox.java.lang.String = sandbox.java.lang.String.toDJVM(this) + private fun Long.toDJVM(): sandbox.java.lang.Long = sandbox.java.lang.Long.toDJVM(this) + private fun Int.toDJVM(): sandbox.java.lang.Integer = sandbox.java.lang.Integer.toDJVM(this) + private fun Short.toDJVM(): sandbox.java.lang.Short = sandbox.java.lang.Short.toDJVM(this) + private fun Byte.toDJVM(): sandbox.java.lang.Byte = sandbox.java.lang.Byte.toDJVM(this) + private fun Float.toDJVM(): sandbox.java.lang.Float = sandbox.java.lang.Float.toDJVM(this) + private fun Double.toDJVM(): sandbox.java.lang.Double = sandbox.java.lang.Double.toDJVM(this) + private fun Char.toDJVM(): sandbox.java.lang.Character = sandbox.java.lang.Character.toDJVM(this) + private fun Boolean.toDJVM(): sandbox.java.lang.Boolean = sandbox.java.lang.Boolean.toDJVM(this) +} \ No newline at end of file diff --git a/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt b/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt index b54d92b16e..a771798655 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt @@ -13,6 +13,7 @@ import net.corda.djvm.messages.Severity import net.corda.djvm.references.ClassHierarchy import net.corda.djvm.rewiring.LoadedClass import net.corda.djvm.rules.Rule +import net.corda.djvm.rules.implementation.* import net.corda.djvm.source.ClassSource import net.corda.djvm.utilities.Discovery import net.corda.djvm.validation.RuleValidator @@ -35,8 +36,19 @@ abstract class TestBase { val ALL_EMITTERS = Discovery.find() + // We need at least these emitters to handle the Java API classes. + val BASIC_EMITTERS: List = listOf( + ArgumentUnwrapper(), + ReturnTypeWrapper(), + RewriteClassMethods(), + StringConstantWrapper() + ) + val ALL_DEFINITION_PROVIDERS = Discovery.find() + // We need at least these providers to handle the Java API classes. + val BASIC_DEFINITION_PROVIDERS: List = listOf(StaticConstantRemover()) + val BLANK = emptySet() val DEFAULT = (ALL_RULES + ALL_EMITTERS + ALL_DEFINITION_PROVIDERS).distinctBy(Any::javaClass) @@ -86,14 +98,6 @@ abstract class TestBase { } } - /** - * Short-hand for analysing a class. - */ - inline fun analyze(block: (ClassAndMemberVisitor.(AnalysisContext) -> Unit)) { - val validator = RuleValidator(emptyList(), configuration) - block(validator, context) - } - /** * Run action on a separate thread to ensure that the code is run off a clean slate. The sandbox context is local to * the current thread, so this allows inspection of the cost summary object, etc. from within the provided delegate. @@ -106,8 +110,8 @@ abstract class TestBase { action: SandboxRuntimeContext.() -> Unit ) { val rules = mutableListOf() - val emitters = mutableListOf() - val definitionProviders = mutableListOf() + val emitters = mutableListOf().apply { addAll(BASIC_EMITTERS) } + val definitionProviders = mutableListOf().apply { addAll(BASIC_DEFINITION_PROVIDERS) } val classSources = mutableListOf() var executionProfile = ExecutionProfile.UNLIMITED var whitelist = Whitelist.MINIMAL @@ -137,7 +141,12 @@ abstract class TestBase { minimumSeverityLevel = minimumSeverityLevel ).use { analysisConfiguration -> SandboxRuntimeContext(SandboxConfiguration.of( - executionProfile, rules, emitters, definitionProviders, enableTracing, analysisConfiguration + executionProfile, + rules.distinctBy(Any::javaClass), + emitters.distinctBy(Any::javaClass), + definitionProviders.distinctBy(Any::javaClass), + enableTracing, + analysisConfiguration )).use { assertThat(runtimeCosts).areZero() action(this) @@ -163,7 +172,7 @@ abstract class TestBase { inline fun SandboxRuntimeContext.loadClass(): LoadedClass = loadClass(T::class.jvmName) fun SandboxRuntimeContext.loadClass(className: String): LoadedClass = - classLoader.loadClassAndBytes(ClassSource.fromClassName(className), context) + classLoader.loadForSandbox(className, context) /** * Run the entry-point of the loaded [Callable] class. diff --git a/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt b/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt new file mode 100644 index 0000000000..d493238723 --- /dev/null +++ b/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt @@ -0,0 +1,22 @@ +package net.corda.djvm + +import sandbox.net.corda.djvm.costing.ThresholdViolationError +import sandbox.net.corda.djvm.rules.RuleViolationError + +object Utilities { + fun throwRuleViolationError(): Nothing = throw RuleViolationError("Can't catch this!") + + fun throwThresholdViolationError(): Nothing = throw ThresholdViolationError("Can't catch this!") + + fun throwContractConstraintViolation(): Nothing = throw IllegalArgumentException("Contract constraint violated") + + fun throwError(): Nothing = throw Error() + + fun throwThrowable(): Nothing = throw Throwable() + + fun throwThreadDeath(): Nothing = throw ThreadDeath() + + fun throwStackOverflowError(): Nothing = throw StackOverflowError("FAKE OVERFLOW!") + + fun throwOutOfMemoryError(): Nothing = throw OutOfMemoryError("FAKE OOM!") +} diff --git a/djvm/src/test/kotlin/net/corda/djvm/analysis/ClassResolverTest.kt b/djvm/src/test/kotlin/net/corda/djvm/analysis/ClassResolverTest.kt index 8f39f32808..d1ae149cb7 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/analysis/ClassResolverTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/analysis/ClassResolverTest.kt @@ -5,25 +5,25 @@ import org.junit.Test class ClassResolverTest { - private val resolver = ClassResolver(emptySet(), Whitelist.MINIMAL, "sandbox/") + private val resolver = ClassResolver(emptySet(), emptySet(), Whitelist.MINIMAL, "sandbox/") @Test fun `can resolve class name`() { assertThat(resolver.resolve("java/lang/Object")).isEqualTo("java/lang/Object") - assertThat(resolver.resolve("java/lang/String")).isEqualTo("java/lang/String") + assertThat(resolver.resolve("java/lang/String")).isEqualTo("sandbox/java/lang/String") assertThat(resolver.resolve("foo/bar/Test")).isEqualTo("sandbox/foo/bar/Test") } @Test fun `can resolve class name for arrays`() { assertThat(resolver.resolve("[Ljava/lang/Object;")).isEqualTo("[Ljava/lang/Object;") - assertThat(resolver.resolve("[Ljava/lang/String;")).isEqualTo("[Ljava/lang/String;") + assertThat(resolver.resolve("[Ljava/lang/String;")).isEqualTo("[Lsandbox/java/lang/String;") assertThat(resolver.resolve("[Lfoo/bar/Test;")).isEqualTo("[Lsandbox/foo/bar/Test;") assertThat(resolver.resolve("[[Ljava/lang/Object;")).isEqualTo("[[Ljava/lang/Object;") - assertThat(resolver.resolve("[[Ljava/lang/String;")).isEqualTo("[[Ljava/lang/String;") + assertThat(resolver.resolve("[[Ljava/lang/String;")).isEqualTo("[[Lsandbox/java/lang/String;") assertThat(resolver.resolve("[[Lfoo/bar/Test;")).isEqualTo("[[Lsandbox/foo/bar/Test;") assertThat(resolver.resolve("[[[Ljava/lang/Object;")).isEqualTo("[[[Ljava/lang/Object;") - assertThat(resolver.resolve("[[[Ljava/lang/String;")).isEqualTo("[[[Ljava/lang/String;") + assertThat(resolver.resolve("[[[Ljava/lang/String;")).isEqualTo("[[[Lsandbox/java/lang/String;") assertThat(resolver.resolve("[[[Lfoo/bar/Test;")).isEqualTo("[[[Lsandbox/foo/bar/Test;") } diff --git a/djvm/src/test/kotlin/net/corda/djvm/analysis/WhitelistTest.kt b/djvm/src/test/kotlin/net/corda/djvm/analysis/WhitelistTest.kt index a817be4108..74d223af13 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/analysis/WhitelistTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/analysis/WhitelistTest.kt @@ -11,8 +11,8 @@ class WhitelistTest : TestBase() { val whitelist = Whitelist.MINIMAL assertThat(whitelist.matches("java/lang/Object")).isTrue() assertThat(whitelist.matches("java/lang/Object.:()V")).isTrue() - assertThat(whitelist.matches("java/lang/Integer")).isTrue() - assertThat(whitelist.matches("java/lang/Integer.:(I)V")).isTrue() + assertThat(whitelist.matches("java/lang/reflect/Array")).isTrue() + assertThat(whitelist.matches("java/lang/reflect/Array.setInt(Ljava/lang/Object;II)V")).isTrue() } @Test diff --git a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt index cc122b7a8f..0957217f47 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt @@ -28,4 +28,9 @@ class AssertiveClassWithByteCode(private val loadedClass: LoadedClass) { assertThat(loadedClass.type.name).isEqualTo(className) return this } + + fun hasInterface(className: String): AssertiveClassWithByteCode { + assertThat(loadedClass.type.interfaces.map(Class<*>::getName)).contains(className) + return this + } } diff --git a/djvm/src/test/kotlin/net/corda/djvm/costing/RuntimeCostTest.kt b/djvm/src/test/kotlin/net/corda/djvm/costing/RuntimeCostTest.kt index 0e68806d33..0eb2d1c03e 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/costing/RuntimeCostTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/costing/RuntimeCostTest.kt @@ -4,6 +4,7 @@ import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.Test import sandbox.net.corda.djvm.costing.ThresholdViolationError +import kotlin.concurrent.thread class RuntimeCostTest { @@ -16,17 +17,13 @@ class RuntimeCostTest { @Test fun `cannot increment cost beyond threshold`() { - Thread { + thread(name = "Foo") { val cost = RuntimeCost(10) { "failed in ${it.name}" } assertThatExceptionOfType(ThresholdViolationError::class.java) .isThrownBy { cost.increment(11) } .withMessage("failed in Foo") assertThat(cost.value).isEqualTo(11) - }.apply { - name = "Foo" - start() - join() - } + }.join() } } diff --git a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt new file mode 100644 index 0000000000..af78c3183b --- /dev/null +++ b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt @@ -0,0 +1,86 @@ +package net.corda.djvm.execution + +import net.corda.djvm.TestBase +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import java.util.* +import java.util.function.Function + +class SandboxEnumTest : TestBase() { + @Test + fun `test enum inside sandbox`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor>(configuration) + contractExecutor.run(0).apply { + assertThat(result).isEqualTo(arrayOf("ONE", "TWO", "THREE")) + } + } + + @Test + fun `return enum from sandbox`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run("THREE").apply { + assertThat(result).isEqualTo(ExampleEnum.THREE) + } + } + + @Test + fun `test we can identify class as Enum`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run(ExampleEnum.THREE).apply { + assertThat(result).isTrue() + } + } + + @Test + fun `test we can create EnumMap`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run(ExampleEnum.TWO).apply { + assertThat(result).isEqualTo(1) + } + } + + @Test + fun `test we can create EnumSet`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run(ExampleEnum.ONE).apply { + assertThat(result).isTrue() + } + } +} + + +class AssertEnum : Function { + override fun apply(input: ExampleEnum): Boolean { + return input::class.java.isEnum + } +} + +class TransformEnum : Function> { + override fun apply(input: Int): Array { + return ExampleEnum.values().map(ExampleEnum::name).toTypedArray() + } +} + +class FetchEnum : Function { + override fun apply(input: String): ExampleEnum { + return ExampleEnum.valueOf(input) + } +} + +class UseEnumMap : Function { + override fun apply(input: ExampleEnum): Int { + val map = EnumMap(ExampleEnum::class.java) + map[input] = input.name + return map.size + } +} + +class UseEnumSet : Function { + override fun apply(input: ExampleEnum): Boolean { + return EnumSet.allOf(ExampleEnum::class.java).contains(input) + } +} + +enum class ExampleEnum { + ONE, TWO, THREE +} \ No newline at end of file diff --git a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt index 92fe59e159..a3919c964c 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt @@ -1,10 +1,19 @@ package net.corda.djvm.execution import foo.bar.sandbox.MyObject -import foo.bar.sandbox.testRandom +import foo.bar.sandbox.testClock import foo.bar.sandbox.toNumber import net.corda.djvm.TestBase import net.corda.djvm.analysis.Whitelist +import net.corda.djvm.Utilities +import net.corda.djvm.Utilities.throwContractConstraintViolation +import net.corda.djvm.Utilities.throwError +import net.corda.djvm.Utilities.throwOutOfMemoryError +import net.corda.djvm.Utilities.throwRuleViolationError +import net.corda.djvm.Utilities.throwStackOverflowError +import net.corda.djvm.Utilities.throwThreadDeath +import net.corda.djvm.Utilities.throwThresholdViolationError +import net.corda.djvm.Utilities.throwThrowable import net.corda.djvm.assertions.AssertionExtensions.withProblem import net.corda.djvm.rewiring.SandboxClassLoadingException import org.assertj.core.api.Assertions.assertThat @@ -13,8 +22,8 @@ import org.junit.Test import sandbox.net.corda.djvm.costing.ThresholdViolationError import sandbox.net.corda.djvm.rules.RuleViolationError import java.nio.file.Files -import java.util.* import java.util.function.Function +import java.util.stream.Collectors.* class SandboxExecutorTest : TestBase() { @@ -34,7 +43,7 @@ class SandboxExecutorTest : TestBase() { @Test fun `can load and execute contract`() = sandbox( - pinnedClasses = setOf(Transaction::class.java) + pinnedClasses = setOf(Transaction::class.java, Utilities::class.java) ) { val contractExecutor = DeterministicSandboxExecutor(configuration) val tx = Transaction(1) @@ -44,13 +53,13 @@ class SandboxExecutorTest : TestBase() { .withMessageContaining("Contract constraint violated") } - class Contract : Function { - override fun apply(input: Transaction?) { - throw IllegalArgumentException("Contract constraint violated") + class Contract : Function { + override fun apply(input: Transaction) { + throwContractConstraintViolation() } } - data class Transaction(val id: Int?) + data class Transaction(val id: Int) @Test fun `can load and execute code that overrides object hash code`() = sandbox(DEFAULT) { @@ -65,7 +74,11 @@ class SandboxExecutorTest : TestBase() { val obj = Object() val hash1 = obj.hashCode() val hash2 = obj.hashCode() - require(hash1 == hash2) + //require(hash1 == hash2) + // TODO: Replace require() once we have working exception support. + if (hash1 != hash2) { + throwError() + } return Object().hashCode() } } @@ -123,37 +136,37 @@ class SandboxExecutorTest : TestBase() { @Test fun `can detect illegal references in Kotlin meta-classes`() = sandbox(DEFAULT, ExecutionProfile.DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(0) } - .withCauseInstanceOf(RuleViolationError::class.java) - .withMessageContaining("Disallowed reference to reflection API") + .withCauseInstanceOf(NoSuchMethodError::class.java) + .withProblem("sandbox.java.lang.System.nanoTime()J") } - class TestKotlinMetaClasses : Function { - override fun apply(input: Int): Int { - val someNumber = testRandom() + class TestKotlinMetaClasses : Function { + override fun apply(input: Int): Long { + val someNumber = testClock() return "12345".toNumber() * someNumber } } @Test fun `cannot execute runnable that references non-deterministic code`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(0) } - .withCauseInstanceOf(RuleViolationError::class.java) - .withProblem("Disallowed reference to reflection API") + .withCauseInstanceOf(NoSuchMethodError::class.java) + .withProblem("sandbox.java.lang.System.currentTimeMillis()J") } - class TestNonDeterministicCode : Function { - override fun apply(input: Int): Int { - return Random().nextInt() + class TestNonDeterministicCode : Function { + override fun apply(input: Int): Long { + return System.currentTimeMillis() } } @Test - fun `cannot execute runnable that catches ThreadDeath`() = sandbox(DEFAULT) { + fun `cannot execute runnable that catches ThreadDeath`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { TestCatchThreadDeath().apply { assertThat(apply(0)).isEqualTo(1) } @@ -167,7 +180,7 @@ class SandboxExecutorTest : TestBase() { class TestCatchThreadDeath : Function { override fun apply(input: Int): Int { return try { - throw ThreadDeath() + throwThreadDeath() } catch (exception: ThreadDeath) { 1 } @@ -175,7 +188,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot execute runnable that catches ThresholdViolationError`() = sandbox(DEFAULT) { + fun `cannot execute runnable that catches ThresholdViolationError`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { TestCatchThresholdViolationError().apply { assertThat(apply(0)).isEqualTo(1) } @@ -190,7 +203,7 @@ class SandboxExecutorTest : TestBase() { class TestCatchThresholdViolationError : Function { override fun apply(input: Int): Int { return try { - throw ThresholdViolationError("Can't catch this!") + throwThresholdViolationError() } catch (exception: ThresholdViolationError) { 1 } @@ -198,7 +211,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot execute runnable that catches RuleViolationError`() = sandbox(DEFAULT) { + fun `cannot execute runnable that catches RuleViolationError`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { TestCatchRuleViolationError().apply { assertThat(apply(0)).isEqualTo(1) } @@ -213,7 +226,7 @@ class SandboxExecutorTest : TestBase() { class TestCatchRuleViolationError : Function { override fun apply(input: Int): Int { return try { - throw RuleViolationError("Can't catch this!") + throwRuleViolationError() } catch (exception: RuleViolationError) { 1 } @@ -221,7 +234,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can catch Throwable`() = sandbox(DEFAULT) { + fun `can catch Throwable`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(1).apply { assertThat(result).isEqualTo(1) @@ -229,7 +242,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can catch Error`() = sandbox(DEFAULT) { + fun `can catch Error`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(2).apply { assertThat(result).isEqualTo(2) @@ -237,7 +250,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot catch ThreadDeath`() = sandbox(DEFAULT) { + fun `cannot catch ThreadDeath`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(3) } @@ -248,8 +261,8 @@ class SandboxExecutorTest : TestBase() { override fun apply(input: Int): Int { return try { when (input) { - 1 -> throw Throwable() - 2 -> throw Error() + 1 -> throwThrowable() + 2 -> throwError() else -> 0 } } catch (exception: Error) { @@ -264,20 +277,20 @@ class SandboxExecutorTest : TestBase() { override fun apply(input: Int): Int { return try { when (input) { - 1 -> throw Throwable() - 2 -> throw Error() + 1 -> throwThrowable() + 2 -> throwError() 3 -> try { - throw ThreadDeath() + throwThreadDeath() } catch (ex: ThreadDeath) { 3 } 4 -> try { - throw StackOverflowError("FAKE OVERFLOW!") + throwStackOverflowError() } catch (ex: StackOverflowError) { 4 } 5 -> try { - throw OutOfMemoryError("FAKE OOM!") + throwOutOfMemoryError() } catch (ex: OutOfMemoryError) { 5 } @@ -292,7 +305,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot catch stack-overflow error`() = sandbox(DEFAULT) { + fun `cannot catch stack-overflow error`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(4) } @@ -301,7 +314,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot catch out-of-memory error`() = sandbox(DEFAULT) { + fun `cannot catch out-of-memory error`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(5) } @@ -371,7 +384,7 @@ class SandboxExecutorTest : TestBase() { @Test fun `can load and execute code that uses notify()`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(1) } .withCauseInstanceOf(RuleViolationError::class.java) @@ -381,7 +394,7 @@ class SandboxExecutorTest : TestBase() { @Test fun `can load and execute code that uses notifyAll()`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(2) } .withCauseInstanceOf(RuleViolationError::class.java) @@ -391,7 +404,7 @@ class SandboxExecutorTest : TestBase() { @Test fun `can load and execute code that uses wait()`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(3) } .withCauseInstanceOf(RuleViolationError::class.java) @@ -401,7 +414,7 @@ class SandboxExecutorTest : TestBase() { @Test fun `can load and execute code that uses wait(long)`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(4) } .withCauseInstanceOf(RuleViolationError::class.java) @@ -411,7 +424,7 @@ class SandboxExecutorTest : TestBase() { @Test fun `can load and execute code that uses wait(long,int)`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(5) } .withCauseInstanceOf(RuleViolationError::class.java) @@ -421,13 +434,13 @@ class SandboxExecutorTest : TestBase() { @Test fun `code after forbidden APIs is intact`() = sandbox(DEFAULT) { - val contractExecutor = DeterministicSandboxExecutor(configuration) + val contractExecutor = DeterministicSandboxExecutor(configuration) assertThat(contractExecutor.run(0).result) .isEqualTo("unknown") } - class TestMonitors : Function { - override fun apply(input: Int): String { + class TestMonitors : Function { + override fun apply(input: Int): String? { return synchronized(this) { @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") val javaObject = this as java.lang.Object @@ -493,6 +506,73 @@ class SandboxExecutorTest : TestBase() { } } + @Test + fun `check building a string`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run("Hello Sandbox!").apply { + assertThat(result) + .isEqualTo("SANDBOX: Boolean=true, Char='X', Integer=1234, Long=99999, Short=3200, Byte=101, String='Hello Sandbox!', Float=123.456, Double=987.6543") + } + } + + class TestStringBuilding : Function { + override fun apply(input: String?): String? { + return StringBuilder("SANDBOX") + .append(": Boolean=").append(true) + .append(", Char='").append('X') + .append("', Integer=").append(1234) + .append(", Long=").append(99999L) + .append(", Short=").append(3200.toShort()) + .append(", Byte=").append(101.toByte()) + .append(", String='").append(input) + .append("', Float=").append(123.456f) + .append(", Double=").append(987.6543) + .toString() + } + } + + @Test + fun `check System-arraycopy still works with Objects`() = sandbox(DEFAULT) { + val source = arrayOf("one", "two", "three") + assertThat(TestArrayCopy().apply(source)) + .isEqualTo(source) + .isNotSameAs(source) + + val contractExecutor = DeterministicSandboxExecutor, Array>(configuration) + contractExecutor.run(source).apply { + assertThat(result) + .isEqualTo(source) + .isNotSameAs(source) + } + } + + class TestArrayCopy : Function, Array> { + override fun apply(input: Array): Array { + val newArray = Array(input.size) { "" } + System.arraycopy(input, 0, newArray, 0, newArray.size) + return newArray + } + } + + @Test + fun `test System-arraycopy still works with CharArray`() = sandbox(DEFAULT) { + val source = CharArray(10) { '?' } + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run(source).apply { + assertThat(result) + .isEqualTo(source) + .isNotSameAs(source) + } + } + + class TestCharArrayCopy : Function { + override fun apply(input: CharArray): CharArray { + val newArray = CharArray(input.size) { 'X' } + System.arraycopy(input, 0, newArray, 0, newArray.size) + return newArray + } + } + @Test fun `can load and execute class that has finalize`() = sandbox(DEFAULT) { assertThatExceptionOfType(UnsupportedOperationException::class.java) @@ -515,4 +595,152 @@ class SandboxExecutorTest : TestBase() { throw UnsupportedOperationException("Very Bad Thing") } } + + @Test + fun `can execute parallel stream`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run("Pebble").apply { + assertThat(result).isEqualTo("Five,Four,One,Pebble,Three,Two") + } + } + + class TestParallelStream : Function { + override fun apply(input: String): String { + return listOf(input, "One", input, "Two", input, "Three", input, "Four", input, "Five") + .stream() + .distinct() + .sorted() + .collect(joining(",")) + } + } + + @Test + fun `users cannot load our sandboxed classes`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor>(configuration) + assertThatExceptionOfType(SandboxException::class.java) + .isThrownBy { contractExecutor.run("java.lang.DJVM") } + .withCauseInstanceOf(ClassNotFoundException::class.java) + .withMessageContaining("java.lang.DJVM") + } + + @Test + fun `users can load sandboxed classes`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor>(configuration) + contractExecutor.run("java.util.List").apply { + assertThat(result?.name).isEqualTo("sandbox.java.util.List") + } + } + + class TestClassForName : Function> { + override fun apply(input: String): Class<*> { + return Class.forName(input) + } + } + + @Test + fun `test case-insensitive string sorting`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor, Array>(configuration) + contractExecutor.run(arrayOf("Zelda", "angela", "BOB", "betsy", "ALBERT")).apply { + assertThat(result).isEqualTo(arrayOf("ALBERT", "angela", "betsy", "BOB", "Zelda")) + } + } + + class CaseInsensitiveSort : Function, Array> { + override fun apply(input: Array): Array { + return listOf(*input).sortedWith(String.CASE_INSENSITIVE_ORDER).toTypedArray() + } + } + + @Test + fun `test unicode characters`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run(0x01f600).apply { + assertThat(result).isEqualTo("EMOTICONS") + } + } + + class ExamineUnicodeBlock : Function { + override fun apply(codePoint: Int): String { + return Character.UnicodeBlock.of(codePoint).toString() + } + } + + @Test + fun `test unicode scripts`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor(configuration) + contractExecutor.run("COMMON").apply { + assertThat(result).isEqualTo(Character.UnicodeScript.COMMON) + } + } + + class ExamineUnicodeScript : Function { + override fun apply(scriptName: String): Character.UnicodeScript? { + val script = Character.UnicodeScript.valueOf(scriptName) + return if (script::class.java.isEnum) script else null + } + } + + @Test + fun `test users cannot define new classes`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor>(configuration) + assertThatExceptionOfType(SandboxException::class.java) + .isThrownBy { contractExecutor.run("sandbox.java.lang.DJVM") } + .withCauseInstanceOf(RuleViolationError::class.java) + .withMessageContaining("Disallowed reference to API;") + .withMessageContaining("java.lang.ClassLoader.defineClass") + } + + class DefineNewClass : Function> { + override fun apply(input: String): Class<*> { + val data = ByteArray(0) + val cl = object : ClassLoader(this::class.java.classLoader) { + fun define(): Class<*> { + return super.defineClass(input, data, 0, data.size) + } + } + return cl.define() + } + } + + @Test + fun `test users cannot load new classes`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor>(configuration) + assertThatExceptionOfType(SandboxException::class.java) + .isThrownBy { contractExecutor.run("sandbox.java.lang.DJVM") } + .withCauseInstanceOf(RuleViolationError::class.java) + .withMessageContaining("Disallowed reference to API;") + .withMessageContaining("java.lang.ClassLoader.loadClass") + } + + class LoadNewClass : Function> { + override fun apply(input: String): Class<*> { + val cl = object : ClassLoader(this::class.java.classLoader) { + fun load(): Class<*> { + return super.loadClass(input) + } + } + return cl.load() + } + } + + @Test + fun `test users cannot lookup classes`() = sandbox(DEFAULT) { + val contractExecutor = DeterministicSandboxExecutor>(configuration) + assertThatExceptionOfType(SandboxException::class.java) + .isThrownBy { contractExecutor.run("sandbox.java.lang.DJVM") } + .withCauseInstanceOf(RuleViolationError::class.java) + .withMessageContaining("Disallowed reference to API;") + .withMessageContaining("java.lang.ClassLoader.findClass") + } + + class FindClass : Function> { + override fun apply(input: String): Class<*> { + val cl = object : ClassLoader(this::class.java.classLoader) { + fun find(): Class<*> { + return super.findClass(input) + } + } + return cl.find() + } + } } diff --git a/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt b/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt index bd7e86dac0..68b60da1cd 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt @@ -1,16 +1,14 @@ package net.corda.djvm.rewiring -import foo.bar.sandbox.A -import foo.bar.sandbox.B -import foo.bar.sandbox.Empty -import foo.bar.sandbox.StrictFloat +import foo.bar.sandbox.* import net.corda.djvm.TestBase import net.corda.djvm.assertions.AssertionExtensions.assertThat import net.corda.djvm.execution.ExecutionProfile -import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.assertj.core.api.Assertions.* import org.junit.Test import sandbox.net.corda.djvm.costing.ThresholdViolationError import java.nio.file.Paths +import java.util.* class ClassRewriterTest : TestBase() { @@ -102,4 +100,44 @@ class ClassRewriterTest : TestBase() { return input } } + + @Test + fun `can load class with constant fields`() = sandbox(DEFAULT) { + assertThat(loadClass()) + .hasClassName("sandbox.net.corda.djvm.rewiring.ObjectWithConstants") + .hasBeenModified() + } + + @Test + fun `test rewrite static method`() = sandbox(DEFAULT) { + assertThat(loadClass()) + .hasClassName("sandbox.java.util.Arrays") + .hasBeenModified() + } + + @Test + fun `test stitch new super-interface`() = sandbox(DEFAULT) { + assertThat(loadClass()) + .hasClassName("sandbox.java.lang.CharSequence") + .hasInterface("java.lang.CharSequence") + .hasBeenModified() + } + + @Test + fun `test class with stitched interface`() = sandbox(DEFAULT) { + assertThat(loadClass()) + .hasClassName("sandbox.java.lang.StringBuilder") + .hasInterface("sandbox.java.lang.CharSequence") + .hasBeenModified() + } } + +@Suppress("unused") +private object ObjectWithConstants { + const val MESSAGE = "Hello Sandbox!" + const val BIG_NUMBER = 99999L + const val NUMBER = 100 + const val CHAR = '?' + const val BYTE = 7f.toByte() + val DATA = Array(0) { "" } +} \ No newline at end of file diff --git a/djvm/src/test/kotlin/net/corda/djvm/source/SourceClassLoaderTest.kt b/djvm/src/test/kotlin/net/corda/djvm/source/SourceClassLoaderTest.kt index 85f4596f1f..401e0b5c0a 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/source/SourceClassLoaderTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/source/SourceClassLoaderTest.kt @@ -10,7 +10,7 @@ import java.nio.file.Path class SourceClassLoaderTest { - private val classResolver = ClassResolver(emptySet(), Whitelist.MINIMAL, "") + private val classResolver = ClassResolver(emptySet(), emptySet(), Whitelist.MINIMAL, "") @Test fun `can load class from Java's lang package when no files are provided to the class loader`() { From e3685f5e8119f3947e63b1151defa0e627f8ae97 Mon Sep 17 00:00:00 2001 From: Anthony Keenan Date: Thu, 11 Oct 2018 18:01:54 +0200 Subject: [PATCH 5/6] Make blobinspector not log to console by default (#4059) --- .../main/kotlin/net/corda/blobinspector/BlobInspector.kt | 8 -------- tools/blobinspector/src/main/resources/log4j2.xml | 5 +++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt index 8887247bf6..8670989fd1 100644 --- a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt +++ b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt @@ -52,14 +52,6 @@ class BlobInspector : CordaCliWrapper("blob-inspector", "Convert AMQP serialised override fun runProgram() = run(System.out) - override fun initLogging() { - if (verbose) { - loggingLevel = Level.TRACE - } - val loggingLevel = loggingLevel.name.toLowerCase(Locale.ENGLISH) - System.setProperty("logLevel", loggingLevel) // This property is referenced from the XML config file. - } - fun run(out: PrintStream): Int { val inputBytes = source!!.readBytes() val bytes = parseToBinaryRelaxed(inputFormatType, inputBytes) diff --git a/tools/blobinspector/src/main/resources/log4j2.xml b/tools/blobinspector/src/main/resources/log4j2.xml index 98b3648e6b..b7a8bfcd2f 100644 --- a/tools/blobinspector/src/main/resources/log4j2.xml +++ b/tools/blobinspector/src/main/resources/log4j2.xml @@ -1,7 +1,8 @@ - off + ${sys:consoleLogLevel:-error} + ${sys:defaultLogLevel:-info} @@ -9,7 +10,7 @@ - + From 8c41ae208da787fd59e084ae549659d504afa9c3 Mon Sep 17 00:00:00 2001 From: Andrius Dagys Date: Thu, 11 Oct 2018 10:45:43 +0100 Subject: [PATCH 6/6] CORDA-535: Remove BFT-Smart related migration parts --- .../migration/node-notary.changelog-init.xml | 11 ----------- .../migration/node-notary.changelog-pkey.xml | 5 ----- .../resources/migration/node-notary.changelog-v1.xml | 1 - 3 files changed, 17 deletions(-) diff --git a/node/src/main/resources/migration/node-notary.changelog-init.xml b/node/src/main/resources/migration/node-notary.changelog-init.xml index 8d0f1bcb6f..7bb5a20c52 100644 --- a/node/src/main/resources/migration/node-notary.changelog-init.xml +++ b/node/src/main/resources/migration/node-notary.changelog-init.xml @@ -5,17 +5,6 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog-ext http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-ext.xsd http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd" logicalFilePath="migration/node-services.changelog-init.xml"> - - - - - - - - - - - diff --git a/node/src/main/resources/migration/node-notary.changelog-pkey.xml b/node/src/main/resources/migration/node-notary.changelog-pkey.xml index c4d7c59376..8130c3b156 100644 --- a/node/src/main/resources/migration/node-notary.changelog-pkey.xml +++ b/node/src/main/resources/migration/node-notary.changelog-pkey.xml @@ -8,9 +8,4 @@ - - - - \ No newline at end of file diff --git a/node/src/main/resources/migration/node-notary.changelog-v1.xml b/node/src/main/resources/migration/node-notary.changelog-v1.xml index 3002133bad..4a7fc0e723 100644 --- a/node/src/main/resources/migration/node-notary.changelog-v1.xml +++ b/node/src/main/resources/migration/node-notary.changelog-v1.xml @@ -6,7 +6,6 @@ logicalFilePath="migration/node-services.changelog-init.xml"> - \ No newline at end of file