diff --git a/src/main/kotlin/core/Services.kt b/src/main/kotlin/core/Services.kt
index 3ce6616056..44e035a08b 100644
--- a/src/main/kotlin/core/Services.kt
+++ b/src/main/kotlin/core/Services.kt
@@ -55,6 +55,7 @@ interface WalletService {
  * interface if/when one is developed.
  */
 interface KeyManagementService {
+    /** Returns a snapshot of the current pubkey->privkey mapping. */
     val keys: Map<PublicKey, PrivateKey>
 
     fun toPrivate(publicKey: PublicKey) = keys[publicKey] ?: throw IllegalStateException("No private key known for requested public key")
diff --git a/src/test/kotlin/core/MockServices.kt b/src/test/kotlin/core/MockServices.kt
index e6407ef535..6f9f8009ad 100644
--- a/src/test/kotlin/core/MockServices.kt
+++ b/src/test/kotlin/core/MockServices.kt
@@ -58,11 +58,20 @@ object MockIdentityService : IdentityService {
     override fun partyFromKey(key: PublicKey): Party? = TEST_KEYS_TO_CORP_MAP[key]
 }
 
-class MockKeyManagementService(
-        override val keys: Map<PublicKey, PrivateKey>,
-        val nextKeys: MutableList<KeyPair> = arrayListOf(generateKeyPair())
-) : KeyManagementService {
-    override fun freshKey() = nextKeys.removeAt(nextKeys.lastIndex)
+class MockKeyManagementService(vararg initialKeys: KeyPair) : KeyManagementService {
+    override val keys: MutableMap<PublicKey, PrivateKey>
+
+    init {
+        keys = initialKeys.map { it.public to it.private }.toMap(HashMap())
+    }
+
+    val nextKeys = LinkedList<KeyPair>()
+
+    override fun freshKey(): KeyPair {
+        val k = nextKeys.poll() ?: generateKeyPair()
+        keys[k.public] = k.private
+        return k
+    }
 }
 
 class MockWalletService(val states: List<StateAndRef<OwnableState>>) : WalletService {
diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
index d88992fc4f..3c4b92b115 100644
--- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
+++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
@@ -55,7 +55,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
             val alicesServices = MockServices(net = alicesNode)
             val bobsServices = MockServices(
                     wallet = MockWalletService(bobsWallet.states),
-                    keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)),
+                    keyManagement = MockKeyManagementService(BOB_KEY),
                     net = bobsNode,
                     storage = MockStorageService()
             )
@@ -105,7 +105,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
             val alicesServices = MockServices(wallet = null, keyManagement = null, net = alicesNode)
             var bobsServices = MockServices(
                     wallet = MockWalletService(wallet.states),
-                    keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)),
+                    keyManagement = MockKeyManagementService(BOB_KEY),
                     net = bobsNode,
                     storage = bobsStorage
             )
@@ -198,7 +198,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
             )
             val bobsServices = MockServices(
                     wallet = MockWalletService(bobsWallet.states),
-                    keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)),
+                    keyManagement = MockKeyManagementService(BOB_KEY),
                     net = bobsNode,
                     storage = MockStorageService(mapOf("validated-transactions" to "bob"))
             )
@@ -286,7 +286,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
             val alicesServices = MockServices(net = alicesNode)
             val bobsServices = MockServices(
                     wallet = MockWalletService(bobsWallet.states),
-                    keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)),
+                    keyManagement = MockKeyManagementService(BOB_KEY),
                     net = bobsNode,
                     storage = MockStorageService(mapOf("validated-transactions" to "bob"))
             )
@@ -336,7 +336,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
             val alicesServices = MockServices(net = alicesNode)
             val bobsServices = MockServices(
                     wallet = MockWalletService(bobsWallet.states),
-                    keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)),
+                    keyManagement = MockKeyManagementService(BOB_KEY),
                     net = bobsNode,
                     storage = MockStorageService(mapOf("validated-transactions" to "bob"))
             )
diff --git a/src/test/kotlin/core/node/E2ETestWalletServiceTest.kt b/src/test/kotlin/core/node/E2ETestWalletServiceTest.kt
index 3a32bea0ef..6b0c78db08 100644
--- a/src/test/kotlin/core/node/E2ETestWalletServiceTest.kt
+++ b/src/test/kotlin/core/node/E2ETestWalletServiceTest.kt
@@ -16,17 +16,18 @@ import core.ServiceHub
 import core.testutils.ALICE
 import core.testutils.ALICE_KEY
 import org.junit.Test
-import java.security.KeyPair
 import java.util.*
 import kotlin.test.assertEquals
 
 class E2ETestWalletServiceTest {
+    val kms = MockKeyManagementService()
     val services: ServiceHub = MockServices(
-        keyManagement = MockKeyManagementService(emptyMap(), arrayListOf<KeyPair>(ALICE_KEY, ALICE_KEY, ALICE_KEY))
+        keyManagement = kms
     )
 
     @Test fun splits() {
         val wallet = E2ETestWalletService(services)
+        kms.nextKeys += Array(3) { ALICE_KEY }
         // Fix the PRNG so that we get the same splits every time.
         wallet.fillWithSomeTestCash(100.DOLLARS, 3, 3, Random(0L))