mirror of
https://github.com/corda/corda.git
synced 2024-12-23 23:02:29 +00:00
commit
c80a00f07e
85
core/src/main/kotlin/core/math/Interpolators.kt
Normal file
85
core/src/main/kotlin/core/math/Interpolators.kt
Normal file
@ -0,0 +1,85 @@
|
||||
package core.math
|
||||
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* Interpolates values between the given data points using a [SplineFunction].
|
||||
*
|
||||
* Implementation uses the Natural Cubic Spline algorithm as described in
|
||||
* R. L. Burden and J. D. Faires (2011), *Numerical Analysis*. 9th ed. Boston, MA: Brooks/Cole, Cengage Learning. p149-150.
|
||||
*/
|
||||
class CubicSplineInterpolator(private val xs: DoubleArray, private val ys: DoubleArray) {
|
||||
init {
|
||||
require(xs.size == ys.size) { "x and y dimensions should match: ${xs.size} != ${ys.size}" }
|
||||
require(xs.size >= 3) { "At least 3 data points are required for interpolation, received: ${xs.size}" }
|
||||
}
|
||||
|
||||
private val splineFunction by lazy { computeSplineFunction() }
|
||||
|
||||
fun interpolate(x: Double): Double {
|
||||
require(x >= xs.first() && x <= xs.last()) { "Can't interpolate below ${xs.first()} or above ${xs.last()}" }
|
||||
return splineFunction.getValue(x)
|
||||
}
|
||||
|
||||
private fun computeSplineFunction(): SplineFunction {
|
||||
val n = xs.size - 1
|
||||
|
||||
// Coefficients of polynomial
|
||||
val b = DoubleArray(n) // linear
|
||||
val c = DoubleArray(n + 1) // quadratic
|
||||
val d = DoubleArray(n) // cubic
|
||||
|
||||
// Helpers
|
||||
val h = DoubleArray(n)
|
||||
val g = DoubleArray(n)
|
||||
|
||||
for (i in 0..n - 1)
|
||||
h[i] = xs[i + 1] - xs[i]
|
||||
for (i in 1..n - 1)
|
||||
g[i] = 3 / h[i] * (ys[i + 1] - ys[i]) - 3 / h[i - 1] * (ys[i] - ys[i - 1])
|
||||
|
||||
// Solve tridiagonal linear system (using Crout Factorization)
|
||||
val m = DoubleArray(n)
|
||||
val z = DoubleArray(n)
|
||||
for (i in 1..n - 1) {
|
||||
val l = 2 * (xs[i + 1] - xs[i - 1]) - h[i - 1] * m[i - 1]
|
||||
m[i] = h[i]/l
|
||||
z[i] = (g[i] - h[i - 1] * z[i - 1]) / l
|
||||
}
|
||||
for (j in n - 1 downTo 0) {
|
||||
c[j] = z[j] - m[j] * c[j + 1]
|
||||
b[j] = (ys[j + 1] - ys[j]) / h[j] - h[j] * (c[j + 1] + 2.0 * c[j]) / 3.0
|
||||
d[j] = (c[j + 1] - c[j]) / (3.0 * h[j])
|
||||
}
|
||||
|
||||
val segmentMap = TreeMap<Double, Polynomial>()
|
||||
for (i in 0..n - 1) {
|
||||
val coefficients = doubleArrayOf(ys[i], b[i], c[i], d[i])
|
||||
segmentMap.put(xs[i], Polynomial(coefficients))
|
||||
}
|
||||
return SplineFunction(segmentMap)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a polynomial function of arbitrary degree
|
||||
* @param coefficients polynomial coefficients in the order of degree (constant first, followed by higher degree term coefficients)
|
||||
*/
|
||||
class Polynomial(private val coefficients: DoubleArray) {
|
||||
private val reversedCoefficients = coefficients.reversed()
|
||||
|
||||
fun getValue(x: Double) = reversedCoefficients.fold(0.0, { result, c -> result * x + c })
|
||||
}
|
||||
|
||||
/**
|
||||
* A *spline* is function piecewise-defined by polynomial functions.
|
||||
* Points at which polynomial pieces connect are known as *knots*.
|
||||
*
|
||||
* @param segmentMap a mapping between a knot and the polynomial that covers the subsequent interval
|
||||
*/
|
||||
class SplineFunction(private val segmentMap: TreeMap<Double, Polynomial>) {
|
||||
fun getValue(x: Double): Double {
|
||||
val (knot, polynomial) = segmentMap.floorEntry(x)
|
||||
return polynomial.getValue(x - knot)
|
||||
}
|
||||
}
|
43
core/src/test/kotlin/core/math/InterpolatorsTest.kt
Normal file
43
core/src/test/kotlin/core/math/InterpolatorsTest.kt
Normal file
@ -0,0 +1,43 @@
|
||||
package core.math
|
||||
|
||||
import org.junit.Assert
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
class InterpolatorsTest {
|
||||
|
||||
@Test
|
||||
fun `throws when key to interpolate is outside the data set`() {
|
||||
val xs = doubleArrayOf(1.0, 2.0, 4.0, 5.0)
|
||||
val interpolator = CubicSplineInterpolator(xs, ys = xs)
|
||||
assertFailsWith<IllegalArgumentException> { interpolator.interpolate(0.0) }
|
||||
assertFailsWith<IllegalArgumentException> { interpolator.interpolate(6.0) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `throws when data set is less than 3 points`() {
|
||||
val xs = doubleArrayOf(1.0, 2.0)
|
||||
assertFailsWith<IllegalArgumentException> { CubicSplineInterpolator(xs, ys = xs) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `returns existing value when key is in data set`() {
|
||||
val xs = doubleArrayOf(1.0, 2.0, 4.0, 5.0)
|
||||
val interpolatedValue = CubicSplineInterpolator(xs, ys = xs).interpolate(2.0)
|
||||
assertEquals(2.0, interpolatedValue)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `interpolates missing values correctly`() {
|
||||
val xs = doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0)
|
||||
val ys = doubleArrayOf(2.0, 4.0, 5.0, 11.0, 10.0)
|
||||
val toInterpolate = doubleArrayOf(1.5, 2.5, 2.8, 3.3, 3.7, 4.3, 4.7)
|
||||
// Expected values generated using R's splinefun (package stats v3.2.4), "natural" method
|
||||
val expected = doubleArrayOf(3.28, 4.03, 4.37, 6.7, 9.46, 11.5, 10.91)
|
||||
|
||||
val interpolator = CubicSplineInterpolator(xs, ys)
|
||||
val actual = toInterpolate.map { interpolator.interpolate(it) }.toDoubleArray()
|
||||
Assert.assertArrayEquals(expected, actual, 0.01)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user