diff --git a/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialFunction.java b/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialFunction.java index 3089bacaa..0894546cc 100644 --- a/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialFunction.java +++ b/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialFunction.java @@ -25,6 +25,8 @@ import org.apache.commons.math3.exception.NullArgumentException; import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; import org.apache.commons.math3.analysis.UnivariateFunction; import org.apache.commons.math3.analysis.ParametricUnivariateFunction; +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; +import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiable; import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.MathUtils; @@ -36,7 +38,7 @@ import org.apache.commons.math3.util.MathUtils; * * @version $Id$ */ -public class PolynomialFunction implements DifferentiableUnivariateFunction, Serializable { +public class PolynomialFunction implements UnivariateDifferentiable, DifferentiableUnivariateFunction, Serializable { /** * Serialization identifier */ @@ -137,6 +139,27 @@ public class PolynomialFunction implements DifferentiableUnivariateFunction, Ser return result; } + + /** {@inheritDoc} + * @since 3.1 + * @throws NoDataException if {@code coefficients} is empty. + * @throws NullArgumentException if {@code coefficients} is {@code null}. + */ + public DerivativeStructure value(final DerivativeStructure t) + throws NullArgumentException, NoDataException { + MathUtils.checkNotNull(coefficients); + int n = coefficients.length; + if (n == 0) { + throw new NoDataException(LocalizedFormats.EMPTY_POLYNOMIALS_COEFFICIENTS_ARRAY); + } + DerivativeStructure result = + new DerivativeStructure(t.getFreeParameters(), t.getOrder(), coefficients[n - 1]); + for (int j = n - 2; j >= 0; j--) { + result = result.multiply(t).add(coefficients[j]); + } + return result; + } + /** * Add a polynomial to the instance. * diff --git a/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialSplineFunction.java b/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialSplineFunction.java index b25ac2d53..ff34f379e 100644 --- a/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialSplineFunction.java +++ b/src/main/java/org/apache/commons/math3/analysis/polynomials/PolynomialSplineFunction.java @@ -21,6 +21,8 @@ import java.util.Arrays; import org.apache.commons.math3.util.MathArrays; import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; +import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiable; import org.apache.commons.math3.exception.OutOfRangeException; import org.apache.commons.math3.exception.NumberIsTooSmallException; import org.apache.commons.math3.exception.DimensionMismatchException; @@ -61,7 +63,7 @@ import org.apache.commons.math3.exception.util.LocalizedFormats; * * @version $Id$ */ -public class PolynomialSplineFunction implements DifferentiableUnivariateFunction { +public class PolynomialSplineFunction implements UnivariateDifferentiable, DifferentiableUnivariateFunction { /** * Spline segment interval delimiters (knots). * Size is n + 1 for n segments. @@ -168,6 +170,28 @@ public class PolynomialSplineFunction implements DifferentiableUnivariateFunctio return new PolynomialSplineFunction(knots, derivativePolynomials); } + + /** {@inheritDoc} + * @since 3.1 + */ + public DerivativeStructure value(final DerivativeStructure t) { + final double t0 = t.getValue(); + if (t0 < knots[0] || t0 > knots[n]) { + throw new OutOfRangeException(t0, knots[0], knots[n]); + } + int i = Arrays.binarySearch(knots, t0); + if (i < 0) { + i = -i - 2; + } + // This will handle the case where t is the last knot value + // There are only n-1 polynomials, so if t is the last knot + // then we will use the last polynomial to calculate the value. + if ( i >= polynomials.length ) { + i--; + } + return polynomials[i].value(t.subtract(knots[i])); + } + /** * Get the number of spline segments. * It is also the number of polynomials and the number of knot points - 1.