diff --git a/src/main/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolator.java b/src/main/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolator.java index 780e6c94c..a248f40c1 100644 --- a/src/main/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolator.java +++ b/src/main/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolator.java @@ -17,13 +17,15 @@ package org.apache.commons.math3.analysis.interpolation; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; -import org.apache.commons.math3.analysis.DifferentiableUnivariateVectorFunction; -import org.apache.commons.math3.analysis.UnivariateVectorFunction; +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; +import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableVectorFunction; import org.apache.commons.math3.analysis.polynomials.PolynomialFunction; -import org.apache.commons.math3.exception.MathIllegalArgumentException; -import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.ZeroException; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.util.ArithmeticUtils; @@ -31,7 +33,7 @@ import org.apache.commons.math3.util.ArithmeticUtils; *

* The interpolation polynomials match all sample points, including both values * and provided derivatives. There is one polynomial for each component of - * the values vector. All polynomial have the same degree. The degree of the + * the values vector. All polynomials have the same degree. The degree of the * polynomials depends on the number of points and number of derivatives at each * point. For example the interpolation polynomials for n sample points without * any derivatives all have degree n-1. The interpolation polynomials for n @@ -49,7 +51,7 @@ import org.apache.commons.math3.util.ArithmeticUtils; * @version $Id$ * @since 3.1 */ -public class HermiteInterpolator implements DifferentiableUnivariateVectorFunction { +public class HermiteInterpolator implements UnivariateDifferentiableVectorFunction { /** Sample abscissae. */ private final List abscissae; @@ -82,11 +84,13 @@ public class HermiteInterpolator implements DifferentiableUnivariateVectorFuncti * (if only one row is passed, it is the value, if two rows are * passed the first one is the value and the second the derivative * and so on) - * @exception MathIllegalArgumentException if the abscissa is equals to a previously - * added sample point + * @exception ZeroException if the abscissa difference between added point + * and a previous point is zero (i.e. the two points are at same abscissa) + * @exception MathArithmeticException if the number of derivatives is larger + * than 20, which prevents computation of a factorial */ public void addSamplePoint(final double x, final double[] ... value) - throws MathIllegalArgumentException { + throws ZeroException, MathArithmeticException { for (int i = 0; i < value.length; ++i) { @@ -106,8 +110,7 @@ public class HermiteInterpolator implements DifferentiableUnivariateVectorFuncti final double[] bottom1 = bottomDiagonal.get(n - (j + 1)); final double inv = 1.0 / (x - abscissae.get(n - (j + 1))); if (Double.isInfinite(inv)) { - throw new MathIllegalArgumentException(LocalizedFormats.DUPLICATED_ABSCISSA_DIVISION_BY_ZERO, - x); + throw new ZeroException(LocalizedFormats.DUPLICATED_ABSCISSA_DIVISION_BY_ZERO, x); } for (int k = 0; k < y.length; ++k) { bottom1[k] = inv * (bottom0[k] - bottom1[k]); @@ -127,10 +130,10 @@ public class HermiteInterpolator implements DifferentiableUnivariateVectorFuncti /** Compute the interpolation polynomials. * @return interpolation polynomials array - * @exception MathIllegalStateException if sample is empty + * @exception NoDataException if sample is empty */ public PolynomialFunction[] getPolynomials() - throws MathIllegalStateException { + throws NoDataException { // safety check checkInterpolation(); @@ -165,10 +168,10 @@ public class HermiteInterpolator implements DifferentiableUnivariateVectorFuncti *

* @param x interpolation abscissa * @return interpolated value - * @exception MathIllegalStateException if sample is empty + * @exception NoDataException if sample is empty */ public double[] value(double x) - throws MathIllegalStateException { + throws NoDataException { // safety check checkInterpolation(); @@ -188,59 +191,46 @@ public class HermiteInterpolator implements DifferentiableUnivariateVectorFuncti } - /** Interpolate first derivative at a specified abscissa. + /** Interpolate value at a specified abscissa. *

- * Calling this method is equivalent to call the {@link PolynomialFunction#value(double) - * value} methods of the derivatives of all polynomials returned by {@link - * #getPolynomials() getPolynomials}, except it builds neither the intermediate - * polynomials nor their derivatives, so this method is faster and numerically more stable. + * Calling this method is equivalent to call the {@link + * PolynomialFunction#value(DerivativeStructure) value} methods of all polynomials + * returned by {@link #getPolynomials() getPolynomials}, except it does not build the + * intermediate polynomials, so this method is faster and numerically more stable. *

* @param x interpolation abscissa - * @return interpolated derivative - * @exception MathIllegalStateException if sample is empty + * @return interpolated value + * @exception NoDataException if sample is empty */ - public double[] derivative(double x) - throws MathIllegalStateException { + public DerivativeStructure[] value(final DerivativeStructure x) + throws NoDataException { // safety check checkInterpolation(); - final double[] derivative = new double[topDiagonal.get(0).length]; - double valueCoeff = 1; - double derivativeCoeff = 0; + final DerivativeStructure[] value = new DerivativeStructure[topDiagonal.get(0).length]; + Arrays.fill(value, x.getField().getZero()); + DerivativeStructure valueCoeff = x.getField().getOne(); for (int i = 0; i < topDiagonal.size(); ++i) { double[] dividedDifference = topDiagonal.get(i); - for (int k = 0; k < derivative.length; ++k) { - derivative[k] += dividedDifference[k] * derivativeCoeff; + for (int k = 0; k < value.length; ++k) { + value[k] = value[k].add(valueCoeff.multiply(dividedDifference[k])); } - final double deltaX = x - abscissae.get(i); - derivativeCoeff = valueCoeff + derivativeCoeff * deltaX; - valueCoeff *= deltaX; + final DerivativeStructure deltaX = x.subtract(abscissae.get(i)); + valueCoeff = valueCoeff.multiply(deltaX); } - return derivative; + return value; } - /** {@inheritDoc}} */ - public UnivariateVectorFunction derivative() { - return new UnivariateVectorFunction() { - - /** {@inheritDoc}} */ - public double[] value(double x) { - return derivative(x); - } - - }; - } - /** Check interpolation can be performed. - * @exception MathIllegalStateException if interpolation cannot be performed + * @exception NoDataException if interpolation cannot be performed * because sample is empty */ - private void checkInterpolation() throws MathIllegalStateException { + private void checkInterpolation() throws NoDataException { if (abscissae.isEmpty()) { - throw new MathIllegalStateException(LocalizedFormats.EMPTY_INTERPOLATION_SAMPLE); + throw new NoDataException(LocalizedFormats.EMPTY_INTERPOLATION_SAMPLE); } } diff --git a/src/test/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolatorTest.java b/src/test/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolatorTest.java index 46228e1f8..7cf5a4499 100644 --- a/src/test/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolatorTest.java +++ b/src/test/java/org/apache/commons/math3/analysis/interpolation/HermiteInterpolatorTest.java @@ -18,7 +18,9 @@ package org.apache.commons.math3.analysis.interpolation; import java.util.Random; +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; import org.apache.commons.math3.analysis.polynomials.PolynomialFunction; +import org.apache.commons.math3.exception.NoDataException; import org.apache.commons.math3.util.FastMath; import org.junit.Assert; import org.junit.Test; @@ -30,8 +32,9 @@ public class HermiteInterpolatorTest { HermiteInterpolator interpolator = new HermiteInterpolator(); interpolator.addSamplePoint(0.0, new double[] { 0.0 }); for (double x = -10; x < 10; x += 1.0) { - Assert.assertEquals(0.0, interpolator.value(x)[0], 1.0e-15); - Assert.assertEquals(0.0, interpolator.derivative(x)[0], 1.0e-15); + DerivativeStructure y = interpolator.value(new DerivativeStructure(1, 1, 0, x))[0]; + Assert.assertEquals(0.0, y.getValue(), 1.0e-15); + Assert.assertEquals(0.0, y.getPartialDerivative(1), 1.0e-15); } checkPolynomial(new PolynomialFunction(new double[] { 0.0 }), interpolator.getPolynomials()[0]); @@ -44,8 +47,9 @@ public class HermiteInterpolatorTest { interpolator.addSamplePoint(1.0, new double[] { 0.0 }); interpolator.addSamplePoint(2.0, new double[] { 0.0 }); for (double x = -10; x < 10; x += 1.0) { - Assert.assertEquals((x - 1.0) * (x - 2.0), interpolator.value(x)[0], 1.0e-15); - Assert.assertEquals(2 * x - 3.0, interpolator.derivative(x)[0], 1.0e-15); + DerivativeStructure y = interpolator.value(new DerivativeStructure(1, 1, 0, x))[0]; + Assert.assertEquals((x - 1.0) * (x - 2.0), y.getValue(), 1.0e-15); + Assert.assertEquals(2 * x - 3.0, y.getPartialDerivative(1), 1.0e-15); } checkPolynomial(new PolynomialFunction(new double[] { 2.0, -3.0, 1.0 }), interpolator.getPolynomials()[0]); @@ -58,11 +62,13 @@ public class HermiteInterpolatorTest { interpolator.addSamplePoint(1.0, new double[] { 4.0 }); interpolator.addSamplePoint(2.0, new double[] { 5.0 }, new double[] { 2.0 }); Assert.assertEquals(4, interpolator.getPolynomials()[0].degree()); - Assert.assertEquals(1.0, interpolator.value(0.0)[0], 1.0e-15); - Assert.assertEquals(2.0, interpolator.derivative(0.0)[0], 1.0e-15); + DerivativeStructure y0 = interpolator.value(new DerivativeStructure(1, 1, 0, 0.0))[0]; + Assert.assertEquals(1.0, y0.getValue(), 1.0e-15); + Assert.assertEquals(2.0, y0.getPartialDerivative(1), 1.0e-15); Assert.assertEquals(4.0, interpolator.value(1.0)[0], 1.0e-15); - Assert.assertEquals(5.0, interpolator.value(2.0)[0], 1.0e-15); - Assert.assertEquals(2.0, interpolator.derivative(2.0)[0], 1.0e-15); + DerivativeStructure y2 = interpolator.value(new DerivativeStructure(1, 1, 0, 2.0))[0]; + Assert.assertEquals(5.0, y2.getValue(), 1.0e-15); + Assert.assertEquals(2.0, y2.getPartialDerivative(1), 1.0e-15); checkPolynomial(new PolynomialFunction(new double[] { 1.0, 2.0, 4.0, -4.0, 1.0 }), interpolator.getPolynomials()[0]); } @@ -138,12 +144,11 @@ public class HermiteInterpolatorTest { } for (double x = 0; x < 2; x += 0.1) { - double[] values = interpolator.value(x); - double[] derivatives = interpolator.derivative(x); - Assert.assertEquals(p.length, values.length); + DerivativeStructure[] y = interpolator.value(new DerivativeStructure(1, 1, 0, x)); + Assert.assertEquals(p.length, y.length); for (int k = 0; k < p.length; ++k) { - Assert.assertEquals(p[k].value(x), values[k], 1.0e-8 * FastMath.abs(p[k].value(x))); - Assert.assertEquals(pPrime[k].value(x), derivatives[k], 4.0e-8 * FastMath.abs(p[k].value(x))); + Assert.assertEquals(p[k].value(x), y[k].getValue(), 1.0e-8 * FastMath.abs(p[k].value(x))); + Assert.assertEquals(pPrime[k].value(x), y[k].getPartialDerivative(1), 4.0e-8 * FastMath.abs(p[k].value(x))); } } @@ -162,8 +167,10 @@ public class HermiteInterpolatorTest { interpolator.addSamplePoint(x, new double[] { FastMath.sin(x) }); } for (double x = 0.1; x <= 2.9; x += 0.01) { - Assert.assertEquals(FastMath.sin(x), interpolator.value(x)[0], 3.5e-5); - Assert.assertEquals(FastMath.cos(x), interpolator.derivative(x)[0], 1.3e-4); + DerivativeStructure y = interpolator.value(new DerivativeStructure(1, 2, 0, x))[0]; + Assert.assertEquals( FastMath.sin(x), y.getValue(), 3.5e-5); + Assert.assertEquals( FastMath.cos(x), y.getPartialDerivative(1), 1.3e-4); + Assert.assertEquals(-FastMath.sin(x), y.getPartialDerivative(2), 2.9e-3); } } @@ -174,8 +181,9 @@ public class HermiteInterpolatorTest { interpolator.addSamplePoint(x, new double[] { FastMath.sqrt(x) }); } for (double x = 1.1; x < 3.5; x += 0.01) { - Assert.assertEquals(FastMath.sqrt(x), interpolator.value(x)[0], 1.5e-4); - Assert.assertEquals(0.5 / FastMath.sqrt(x), interpolator.derivative(x)[0], 8.5e-4); + DerivativeStructure y = interpolator.value(new DerivativeStructure(1, 1, 0, x))[0]; + Assert.assertEquals(FastMath.sqrt(x), y.getValue(), 1.5e-4); + Assert.assertEquals(0.5 / FastMath.sqrt(x), y.getPartialDerivative(1), 8.5e-4); } } @@ -188,11 +196,12 @@ public class HermiteInterpolatorTest { interpolator.addSamplePoint( 0, new double[] { 1 }, new double[] { 0 }, new double[] { 0 }); interpolator.addSamplePoint( 1, new double[] { 2 }, new double[] { 8 }, new double[] { 56 }); for (double x = -1.0; x <= 1.0; x += 0.125) { + DerivativeStructure y = interpolator.value(new DerivativeStructure(1, 1, 0, x))[0]; double x2 = x * x; double x4 = x2 * x2; double x8 = x4 * x4; - Assert.assertEquals(x8 + 1, interpolator.value(x)[0], 1.0e-15); - Assert.assertEquals(8 * x4 * x2 * x, interpolator.derivative(x)[0], 1.0e-15); + Assert.assertEquals(x8 + 1, y.getValue(), 1.0e-15); + Assert.assertEquals(8 * x4 * x2 * x, y.getPartialDerivative(1), 1.0e-15); } checkPolynomial(new PolynomialFunction(new double[] { 1, 0, 0, 0, 0, 0, 0, 0, 1 }), interpolator.getPolynomials()[0]); @@ -203,8 +212,9 @@ public class HermiteInterpolatorTest { HermiteInterpolator interpolator = new HermiteInterpolator(); interpolator.addSamplePoint(0, new double[] { 1 }, new double[] { 1 }, new double[] { 2 }); for (double x = -1.0; x <= 1.0; x += 0.125) { - Assert.assertEquals(1 + x * (1 + x), interpolator.value(x)[0], 1.0e-15); - Assert.assertEquals(1 + 2 * x, interpolator.derivative(x)[0], 1.0e-15); + DerivativeStructure y = interpolator.value(new DerivativeStructure(1, 1, 0, x))[0]; + Assert.assertEquals(1 + x * (1 + x), y.getValue(), 1.0e-15); + Assert.assertEquals(1 + 2 * x, y.getPartialDerivative(1), 1.0e-15); } checkPolynomial(new PolynomialFunction(new double[] { 1, 1, 1 }), interpolator.getPolynomials()[0]); @@ -218,7 +228,7 @@ public class HermiteInterpolatorTest { return new PolynomialFunction(coeff); } - @Test(expected=IllegalStateException.class) + @Test(expected=NoDataException.class) public void testEmptySample() { new HermiteInterpolator().value(0.0); }