From 5f4401d16a1dcbeda688b30a1b118afed4a0d9a5 Mon Sep 17 00:00:00 2001
From: Dan Newton <danknewton@hotmail.com>
Date: Wed, 13 May 2020 15:17:51 +0100
Subject: [PATCH] NOTICK Non-database error handling in `withEntityManager`
 (#6239)

When a non-database exception is thrown out of a `withEntityManager`
block, always check if the session needs to be rolled back.

This means if a database error is caught and a new non-database error is
thrown out of the `withEntityManager` block, the transaction is still
rolled back. The flow can then continue progressing as normal.
---
 .../corda/node/flows/FlowEntityManagerTest.kt | 126 ++++++++++++++++++
 .../net/corda/node/internal/AbstractNode.kt   |   7 +-
 2 files changed, 129 insertions(+), 4 deletions(-)

diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowEntityManagerTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowEntityManagerTest.kt
index f8c68f0cb6..add2317a64 100644
--- a/node/src/integration-test/kotlin/net/corda/node/flows/FlowEntityManagerTest.kt
+++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowEntityManagerTest.kt
@@ -32,6 +32,7 @@ import net.corda.testing.driver.driver
 import org.hibernate.exception.ConstraintViolationException
 import org.junit.Before
 import org.junit.Test
+import java.lang.RuntimeException
 import java.sql.Connection
 import java.util.concurrent.ExecutorService
 import java.util.concurrent.Executors
@@ -39,6 +40,7 @@ import java.util.concurrent.Semaphore
 import javax.persistence.PersistenceException
 import kotlin.test.assertEquals
 
+@Suppress("TooGenericExceptionCaught", "TooGenericExceptionThrown")
 class FlowEntityManagerTest : AbstractFlowEntityManagerTest() {
 
     @Before
@@ -364,6 +366,62 @@ class FlowEntityManagerTest : AbstractFlowEntityManagerTest() {
         }
     }
 
+    @Test(timeout = 300_000)
+    fun `non database error caught outside entity manager does not save entities`() {
+        var counter = 0
+        StaffedFlowHospital.onFlowDischarged.add { _, _ -> ++counter }
+
+        driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true)) {
+
+            val alice = startNode(providedName = ALICE_NAME).getOrThrow()
+            alice.rpc.startFlow(::EntityManagerSaveAndThrowNonDatabaseErrorFlow)
+                .returnValue.getOrThrow(30.seconds)
+            assertEquals(0, counter)
+            val entities = alice.rpc.startFlow(::GetCustomEntities).returnValue.getOrThrow()
+            assertEquals(0, entities.size)
+
+        }
+    }
+
+    @Test(timeout = 300_000)
+    fun `non database error caught outside entity manager after flush occurs does save entities`() {
+        var counter = 0
+        StaffedFlowHospital.onFlowDischarged.add { _, _ -> ++counter }
+
+        driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true)) {
+
+            val alice = startNode(providedName = ALICE_NAME).getOrThrow()
+            alice.rpc.startFlow(::EntityManagerSaveFlushAndThrowNonDatabaseErrorFlow)
+                .returnValue.getOrThrow(30.seconds)
+            assertEquals(0, counter)
+            val entities = alice.rpc.startFlow(::GetCustomEntities).returnValue.getOrThrow()
+            assertEquals(3, entities.size)
+        }
+    }
+
+    @Test(timeout = 300_000)
+    fun `database error caught inside entity manager non database exception thrown and caught outside entity manager should not save entities`() {
+        var counter = 0
+        StaffedFlowHospital.onFlowDischarged.add { _, _ -> ++counter }
+
+        driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true)) {
+
+            val alice = startNode(providedName = ALICE_NAME).getOrThrow()
+            alice.rpc.expectFlowSuccessAndAssertCreatedEntities(
+                flow = ::EntityManagerCatchDatabaseErrorInsideEntityManagerThrowNonDatabaseErrorAndCatchOutsideFlow,
+                commitStatus = CommitStatus.NO_INTERMEDIATE_COMMIT,
+                numberOfDischarges = 0,
+                numberOfExpectedEntities = 1
+            )
+            alice.rpc.expectFlowSuccessAndAssertCreatedEntities(
+                flow = ::EntityManagerCatchDatabaseErrorInsideEntityManagerThrowNonDatabaseErrorAndCatchOutsideFlow,
+                commitStatus = CommitStatus.INTERMEDIATE_COMMIT,
+                numberOfDischarges = 0,
+                numberOfExpectedEntities = 1
+            )
+        }
+    }
+
     @StartableByRPC
     class EntityManagerSaveEntitiesWithoutAFlushFlow : FlowLogic<Unit>() {
 
@@ -706,6 +764,74 @@ class FlowEntityManagerTest : AbstractFlowEntityManagerTest() {
         }
     }
 
+    @StartableByRPC
+    class EntityManagerSaveAndThrowNonDatabaseErrorFlow : FlowLogic<Unit>() {
+
+        @Suspendable
+        override fun call() {
+            try {
+                serviceHub.withEntityManager {
+                    persist(entityWithIdOne)
+                    persist(entityWithIdTwo)
+                    persist(entityWithIdThree)
+                    throw RuntimeException("die")
+                }
+            } catch (e: RuntimeException) {
+                logger.info("Caught error")
+            }
+            sleep(1.millis)
+        }
+    }
+
+    @StartableByRPC
+    class EntityManagerSaveFlushAndThrowNonDatabaseErrorFlow : FlowLogic<Unit>() {
+
+        @Suspendable
+        override fun call() {
+            try {
+                serviceHub.withEntityManager {
+                    persist(entityWithIdOne)
+                    persist(entityWithIdTwo)
+                    persist(entityWithIdThree)
+                    flush()
+                    throw RuntimeException("die")
+                }
+            } catch (e: RuntimeException) {
+                logger.info("Caught error")
+            }
+            sleep(1.millis)
+        }
+    }
+
+    @StartableByRPC
+    class EntityManagerCatchDatabaseErrorInsideEntityManagerThrowNonDatabaseErrorAndCatchOutsideFlow(private val commitStatus: CommitStatus) :
+        FlowLogic<Unit>() {
+
+        @Suspendable
+        override fun call() {
+            serviceHub.withEntityManager {
+                persist(entityWithIdOne)
+            }
+            if (commitStatus == CommitStatus.INTERMEDIATE_COMMIT) {
+                sleep(1.millis)
+            }
+            try {
+                serviceHub.withEntityManager {
+                    persist(anotherEntityWithIdOne)
+                    try {
+                        flush()
+                    } catch (e: PersistenceException) {
+                        logger.info("Caught the exception!")
+                    }
+                    throw RuntimeException("die")
+                }
+            } catch (e: RuntimeException) {
+                logger.info("Caught error")
+            }
+            sleep(1.millis)
+        }
+    }
+
     @CordaService
     class MyService(private val services: AppServiceHub) : SingletonSerializeAsToken() {
 
diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
index 462d310c78..83e35807a6 100644
--- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
+++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
@@ -180,7 +180,7 @@ import java.sql.Savepoint
 import java.time.Clock
 import java.time.Duration
 import java.time.format.DateTimeParseException
-import java.util.*
+import java.util.Properties
 import java.util.concurrent.ExecutorService
 import java.util.concurrent.Executors
 import java.util.concurrent.LinkedBlockingQueue
@@ -190,8 +190,6 @@ import java.util.concurrent.TimeUnit.MINUTES
 import java.util.concurrent.TimeUnit.SECONDS
 import java.util.function.Consumer
 import javax.persistence.EntityManager
-import javax.persistence.PersistenceException
-import kotlin.collections.ArrayList
 
 /**
  * A base node implementation that can be customised either for production (with real implementations that do real
@@ -1166,6 +1164,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
 
         override fun jdbcSession(): Connection = database.createSession()
 
+        @Suppress("TooGenericExceptionCaught")
         override fun <T : Any?> withEntityManager(block: EntityManager.() -> T): T {
             return database.transaction(useErrorHandler = false) {
                 session.flush()
@@ -1180,7 +1179,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
                                 connection.rollback(savepoint)
                             }
                         }
-                    } catch (e: PersistenceException) {
+                    } catch (e: Exception) {
                         if (manager.transaction.rollbackOnly) {
                             connection.rollback(savepoint)
                         }