diff --git a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt index ffe765d274..96a43e80e7 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt @@ -2,6 +2,7 @@ package net.corda.core.utilities import net.corda.core.DeleteForDJVM import net.corda.core.internal.STRUCTURAL_STEP_PREFIX +import net.corda.core.internal.warnOnce import net.corda.core.serialization.CordaSerializable import rx.Observable import rx.Subscription @@ -34,6 +35,10 @@ import java.util.* @DeleteForDJVM class ProgressTracker(vararg inputSteps: Step) { + private companion object { + private val log = contextLogger() + } + @CordaSerializable @DeleteForDJVM sealed class Change(val progressTracker: ProgressTracker) { @@ -55,6 +60,11 @@ class ProgressTracker(vararg inputSteps: Step) { */ @CordaSerializable open class Step(open val label: String) { + private fun definitionLocation(): String = Exception().stackTrace.first { it.className != ProgressTracker.Step::class.java.name }.let { "${it.className}:${it.lineNumber}" } + + // Required when Steps with the same name are defined in multiple places. + private val discriminator: String = definitionLocation() + open val changes: Observable get() = Observable.empty() open fun childProgressTracker(): ProgressTracker? = null /** @@ -63,6 +73,17 @@ class ProgressTracker(vararg inputSteps: Step) { * Even if empty the basic details (i.e. label) of the step will be recorded for audit purposes. */ open val extraAuditData: Map get() = emptyMap() + + override fun equals(other: Any?) = when (other) { + is Step -> this.label == other.label && this.discriminator == other.discriminator + else -> false + } + + override fun hashCode(): Int { + var result = label.hashCode() + result = 31 * result + discriminator.hashCode() + return result + } } // Sentinel objects. Overrides equals() to survive process restarts and serialization. @@ -89,7 +110,12 @@ class ProgressTracker(vararg inputSteps: Step) { /** * The steps in this tracker, same as the steps passed to the constructor but with UNSTARTED and DONE inserted. */ - val steps = arrayOf(UNSTARTED, STARTING, *inputSteps, DONE) + val steps = arrayOf(UNSTARTED, STARTING, *inputSteps, DONE).also { stepsArray -> + val labels = stepsArray.map { it.label } + if (labels.toSet().size < labels.size) { + log.warnOnce("Found ProgressTracker Step(s) with the same label: ${labels.groupBy { it }.filter { it.value.size > 1 }.map { it.key }}") + } + } private var _allStepsCache: List> = _allSteps() @@ -137,7 +163,6 @@ class ProgressTracker(vararg inputSteps: Step) { } } - init { steps.forEach { configureChildTrackerForStep(it) diff --git a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt index 13769fcb0f..fa2569a3cb 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt @@ -1,13 +1,25 @@ package net.corda.core.utilities +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.core.utilities.ProgressTrackerTest.NonSingletonSteps.first +import net.corda.core.utilities.ProgressTrackerTest.NonSingletonSteps.first2 +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import org.assertj.core.api.Assertions.assertThat import org.junit.Before +import org.junit.Rule import org.junit.Test import java.util.* import kotlin.test.assertEquals import kotlin.test.assertFails +import kotlin.test.assertNotEquals class ProgressTrackerTest { + + @Rule + @JvmField + val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() + object SimpleSteps { object ONE : ProgressTracker.Step("one") object TWO : ProgressTracker.Step("two") @@ -92,7 +104,7 @@ class ProgressTrackerTest { assertEquals(pt2.currentStep, ProgressTracker.UNSTARTED) assertEquals(ProgressTracker.STARTING, pt2.nextStep()) assertEquals(ChildSteps.AYY, pt2.nextStep()) - assertEquals((stepNotification.last as ProgressTracker.Change.Position).newStep, ChildSteps.AYY) + assertEquals((stepNotification.last as ProgressTracker.Change.Position).newStep, ChildSteps.AYY) assertEquals(ChildSteps.BEE, pt2.nextStep()) } @@ -112,7 +124,7 @@ class ProgressTrackerTest { stepsTreeNotification += it } - fun assertCurrentStepsTree(index:Int, step: ProgressTracker.Step) { + fun assertCurrentStepsTree(index: Int, step: ProgressTracker.Step) { assertEquals(index, pt.stepsTreeIndex) assertEquals(step, allSteps[pt.stepsTreeIndex].second) } @@ -169,7 +181,7 @@ class ProgressTrackerTest { assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(0, 1, 4, 7)) assertThat(stepsTreeNotification).hasSize(3) // The initial tree state, plus one per update } - + @Test fun `structure changes are pushed down when progress trackers are added`() { pt.setChildProgressTracker(SimpleSteps.TWO, pt2) @@ -186,7 +198,7 @@ class ProgressTrackerTest { stepsTreeNotification += it } - fun assertCurrentStepsTree(index:Int, step: ProgressTracker.Step) { + fun assertCurrentStepsTree(index: Int, step: ProgressTracker.Step) { assertEquals(index, pt.stepsTreeIndex) assertEquals(step.label, stepsTreeNotification.last()[pt.stepsTreeIndex].second) } @@ -223,7 +235,7 @@ class ProgressTrackerTest { stepsTreeNotification += it } - fun assertCurrentStepsTree(index:Int, step: ProgressTracker.Step) { + fun assertCurrentStepsTree(index: Int, step: ProgressTracker.Step) { assertEquals(index, pt.stepsTreeIndex) assertEquals(step.label, stepsTreeNotification.last()[pt.stepsTreeIndex].second) } @@ -273,7 +285,7 @@ class ProgressTrackerTest { pt.nextStep() pt.nextStep() pt.nextStep() - pt.changes.subscribe { steps.add(it.toString())} + pt.changes.subscribe { steps.add(it.toString()) } pt.nextStep() pt.nextStep() pt.nextStep() @@ -290,7 +302,7 @@ class ProgressTrackerTest { pt.setChildProgressTracker(SimpleSteps.TWO, pt3) val thirdStepLabels = pt.allStepsLabels - pt.stepsTreeChanges.subscribe { stepTreeNotifications.add(it)} + pt.stepsTreeChanges.subscribe { stepTreeNotifications.add(it) } // Should have one notification for original tree, then one for each time it changed. assertEquals(3, stepTreeNotifications.size) @@ -320,4 +332,40 @@ class ProgressTrackerTest { fun `cannot assign step not belonging to this progress tracker`() { assertFails { pt.currentStep = BabySteps.UNOS } } + + object NonSingletonSteps { + val first = ProgressTracker.Step("first") + val second = ProgressTracker.Step("second") + val first2 = ProgressTracker.Step("first") + fun tracker() = ProgressTracker(first, second, first2) + } + + @Test + fun `Serializing and deserializing a tracker maintains equality`() { + val step = NonSingletonSteps.first + val recreatedStep = step + .checkpointSerialize(testCheckpointSerialization.checkpointSerializationContext) + .checkpointDeserialize(testCheckpointSerialization.checkpointSerializationContext) + assertEquals(step, recreatedStep) + } + + @Test + fun `can assign a recreated equal step`() { + val tracker = NonSingletonSteps.tracker() + val recreatedStep = first + .checkpointSerialize(testCheckpointSerialization.checkpointSerializationContext) + .checkpointDeserialize(testCheckpointSerialization.checkpointSerializationContext) + tracker.currentStep = recreatedStep + } + + @Test + fun `Steps with the same label defined in different places are not equal`() { + val one = ProgressTracker.Step("one") + assertNotEquals(one, SimpleSteps.ONE) + } + + @Test + fun `Steps with the same label defined in the same place are also not equal`() { + assertNotEquals(first, first2) + } }