Converters for univariate and multivariate differentiable functions.
JIRA: MATH-1143
This commit is contained in:
parent
cb21480cb1
commit
613afdb0c3
|
@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</release>
|
||||
|
||||
<release version="4.0" date="XXXX-XX-XX" description="">
|
||||
<action dev="luc" type="add" issue="MATH-1143">
|
||||
Added helper methods to FunctionUtils for univariate and multivariate differentiable functions conversion.
|
||||
</action>
|
||||
<action dev="tn" type="fix" issue="MATH-964">
|
||||
Removed unused package private class PollardRho in package primes.
|
||||
</action>
|
||||
|
|
|
@ -18,12 +18,14 @@
|
|||
package org.apache.commons.math4.analysis;
|
||||
|
||||
import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
|
||||
import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
|
||||
import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
|
||||
import org.apache.commons.math4.analysis.function.Identity;
|
||||
import org.apache.commons.math4.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math4.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math4.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math4.util.MathArrays;
|
||||
|
||||
/**
|
||||
* Utilities for manipulating function objects.
|
||||
|
@ -337,4 +339,206 @@ public class FunctionUtils {
|
|||
return s;
|
||||
}
|
||||
|
||||
/** Convert regular functions to {@link UnivariateDifferentiableFunction}.
|
||||
* <p>
|
||||
* This method handle the case with one free parameter and several derivatives.
|
||||
* For the case with several free parameters and only first order derivatives,
|
||||
* see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
|
||||
* There are no direct support for intermediate cases, with several free parameters
|
||||
* and order 2 or more derivatives, as is would be difficult to specify all the
|
||||
* cross derivatives.
|
||||
* </p>
|
||||
* <p>
|
||||
* Note that the derivatives are expected to be computed only with respect to the
|
||||
* raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
|
||||
* Even if the built function is later used in a composition like f(sin(t)), the provided
|
||||
* derivatives should <em>not</em> apply the composition with sine and its derivatives by
|
||||
* themselves. The composition will be done automatically here and the result will properly
|
||||
* contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
|
||||
* provided derivatives functions know nothing about the sine function.
|
||||
* </p>
|
||||
* @param f base function f(x)
|
||||
* @param derivatives derivatives of the base function, in increasing differentiation order
|
||||
* @return a differentiable function with value and all specified derivatives
|
||||
* @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
|
||||
* @see #derivative(UnivariateDifferentiableFunction, int)
|
||||
*/
|
||||
public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
|
||||
final UnivariateFunction ... derivatives) {
|
||||
|
||||
return new UnivariateDifferentiableFunction() {
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double value(final double x) {
|
||||
return f.value(x);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public DerivativeStructure value(final DerivativeStructure x) {
|
||||
if (x.getOrder() > derivatives.length) {
|
||||
throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
|
||||
}
|
||||
final double[] packed = new double[x.getOrder() + 1];
|
||||
packed[0] = f.value(x.getValue());
|
||||
for (int i = 0; i < x.getOrder(); ++i) {
|
||||
packed[i + 1] = derivatives[i].value(x.getValue());
|
||||
}
|
||||
return x.compose(packed);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
/** Convert regular functions to {@link MultivariateDifferentiableFunction}.
|
||||
* <p>
|
||||
* This method handle the case with several free parameters and only first order derivatives.
|
||||
* For the case with one free parameter and several derivatives,
|
||||
* see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
|
||||
* There are no direct support for intermediate cases, with several free parameters
|
||||
* and order 2 or more derivatives, as is would be difficult to specify all the
|
||||
* cross derivatives.
|
||||
* </p>
|
||||
* <p>
|
||||
* Note that the gradient is expected to be computed only with respect to the
|
||||
* raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
|
||||
* Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
|
||||
* gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
|
||||
* itself. The composition will be done automatically here and the result will properly
|
||||
* contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
|
||||
* know nothing about the sine or cosine functions.
|
||||
* </p>
|
||||
* @param f base function f(x)
|
||||
* @param gradient gradient of the base function
|
||||
* @return a differentiable function with value and gradient
|
||||
* @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
|
||||
* @see #derivative(MultivariateDifferentiableFunction, int[])
|
||||
*/
|
||||
public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
|
||||
final MultivariateVectorFunction gradient) {
|
||||
|
||||
return new MultivariateDifferentiableFunction() {
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double value(final double[] point) {
|
||||
return f.value(point);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public DerivativeStructure value(final DerivativeStructure[] point) {
|
||||
|
||||
// set up the input parameters
|
||||
final double[] dPoint = new double[point.length];
|
||||
for (int i = 0; i < point.length; ++i) {
|
||||
dPoint[i] = point[i].getValue();
|
||||
if (point[i].getOrder() > 1) {
|
||||
throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
|
||||
}
|
||||
}
|
||||
|
||||
// evaluate regular functions
|
||||
final double v = f.value(dPoint);
|
||||
final double[] dv = gradient.value(dPoint);
|
||||
if (dv.length != point.length) {
|
||||
// the gradient function is inconsistent
|
||||
throw new DimensionMismatchException(dv.length, point.length);
|
||||
}
|
||||
|
||||
// build the combined derivative
|
||||
final int parameters = point[0].getFreeParameters();
|
||||
final double[] partials = new double[point.length];
|
||||
final double[] packed = new double[parameters + 1];
|
||||
packed[0] = v;
|
||||
final int orders[] = new int[parameters];
|
||||
for (int i = 0; i < parameters; ++i) {
|
||||
|
||||
// we differentiate once with respect to parameter i
|
||||
orders[i] = 1;
|
||||
for (int j = 0; j < point.length; ++j) {
|
||||
partials[j] = point[j].getPartialDerivative(orders);
|
||||
}
|
||||
orders[i] = 0;
|
||||
|
||||
// compose partial derivatives
|
||||
packed[i + 1] = MathArrays.linearCombination(dv, partials);
|
||||
|
||||
}
|
||||
|
||||
return new DerivativeStructure(parameters, 1, packed);
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
/** Convert an {@link UnivariateDifferentiableFunction} to an
|
||||
* {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
|
||||
* <p>
|
||||
* This converter is only a convenience method. Beware computing only one derivative does
|
||||
* not save any computation as the original function will really be called under the hood.
|
||||
* The derivative will be extracted from the full {@link DerivativeStructure} result.
|
||||
* </p>
|
||||
* @param f original function, with value and all its derivatives
|
||||
* @param order of the derivative to extract
|
||||
* @return function computing the derivative at required order
|
||||
* @see #derivative(MultivariateDifferentiableFunction, int[])
|
||||
* @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
|
||||
*/
|
||||
public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
|
||||
return new UnivariateFunction() {
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double value(final double x) {
|
||||
final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x);
|
||||
return f.value(dsX).getPartialDerivative(order);
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert an {@link MultivariateDifferentiableFunction} to an
|
||||
* {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
|
||||
* <p>
|
||||
* This converter is only a convenience method. Beware computing only one derivative does
|
||||
* not save any computation as the original function will really be called under the hood.
|
||||
* The derivative will be extracted from the full {@link DerivativeStructure} result.
|
||||
* </p>
|
||||
* @param f original function, with value and all its derivatives
|
||||
* @param orders of the derivative to extract, for each free parameters
|
||||
* @return function computing the derivative at required order
|
||||
* @see #derivative(UnivariateDifferentiableFunction, int)
|
||||
* @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
|
||||
*/
|
||||
public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
|
||||
return new MultivariateFunction() {
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double value(final double[] point) {
|
||||
|
||||
// the maximum differentiation order is the sum of all orders
|
||||
int sumOrders = 0;
|
||||
for (final int order : orders) {
|
||||
sumOrders += order;
|
||||
}
|
||||
|
||||
// set up the input parameters
|
||||
final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
|
||||
for (int i = 0; i < point.length; ++i) {
|
||||
dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
|
||||
}
|
||||
|
||||
return f.value(dsPoint).getPartialDerivative(orders);
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.commons.math4.analysis;
|
||||
|
||||
import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
|
||||
import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
|
||||
import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
|
||||
import org.apache.commons.math4.analysis.function.Add;
|
||||
import org.apache.commons.math4.analysis.function.Constant;
|
||||
|
@ -35,6 +36,7 @@ import org.apache.commons.math4.analysis.function.Pow;
|
|||
import org.apache.commons.math4.analysis.function.Power;
|
||||
import org.apache.commons.math4.analysis.function.Sin;
|
||||
import org.apache.commons.math4.analysis.function.Sinc;
|
||||
import org.apache.commons.math4.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math4.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math4.util.FastMath;
|
||||
|
@ -233,4 +235,197 @@ public class FunctionUtilsTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToDifferentiableUnivariate() {
|
||||
|
||||
final UnivariateFunction f0 = new UnivariateFunction() {
|
||||
@Override
|
||||
public double value(final double x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
final UnivariateFunction f1 = new UnivariateFunction() {
|
||||
@Override
|
||||
public double value(final double x) {
|
||||
return 2 * x;
|
||||
}
|
||||
};
|
||||
final UnivariateFunction f2 = new UnivariateFunction() {
|
||||
@Override
|
||||
public double value(final double x) {
|
||||
return 2;
|
||||
}
|
||||
};
|
||||
final UnivariateDifferentiableFunction f = FunctionUtils.toDifferentiable(f0, f1, f2);
|
||||
|
||||
for (double t = -1.0; t < 1; t += 0.01) {
|
||||
// x = sin(t)
|
||||
DerivativeStructure dsT = new DerivativeStructure(1, 2, 0, t);
|
||||
DerivativeStructure y = f.value(dsT.sin());
|
||||
Assert.assertEquals(FastMath.sin(t) * FastMath.sin(t), f.value(FastMath.sin(t)), 1.0e-15);
|
||||
Assert.assertEquals(FastMath.sin(t) * FastMath.sin(t), y.getValue(), 1.0e-15);
|
||||
Assert.assertEquals(2 * FastMath.cos(t) * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
|
||||
Assert.assertEquals(2 * (1 - 2 * FastMath.sin(t) * FastMath.sin(t)), y.getPartialDerivative(2), 1.0e-15);
|
||||
}
|
||||
|
||||
try {
|
||||
f.value(new DerivativeStructure(1, 3, 0.0));
|
||||
Assert.fail("an exception should have been thrown");
|
||||
} catch (NumberIsTooLargeException e) {
|
||||
Assert.assertEquals(2, e.getMax());
|
||||
Assert.assertEquals(3, e.getArgument());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToDifferentiableMultivariate() {
|
||||
|
||||
final double a = 1.5;
|
||||
final double b = 0.5;
|
||||
final MultivariateFunction f = new MultivariateFunction() {
|
||||
@Override
|
||||
public double value(final double[] point) {
|
||||
return a * point[0] + b * point[1];
|
||||
}
|
||||
};
|
||||
final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
|
||||
@Override
|
||||
public double[] value(final double[] point) {
|
||||
return new double[] { a, b };
|
||||
}
|
||||
};
|
||||
final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
|
||||
|
||||
for (double t = -1.0; t < 1; t += 0.01) {
|
||||
// x = sin(t), y = cos(t), hence the method really becomes univariate
|
||||
DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, t);
|
||||
DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
|
||||
Assert.assertEquals(a * FastMath.sin(t) + b * FastMath.cos(t), y.getValue(), 1.0e-15);
|
||||
Assert.assertEquals(a * FastMath.cos(t) - b * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
|
||||
}
|
||||
|
||||
for (double u = -1.0; u < 1; u += 0.01) {
|
||||
DerivativeStructure dsU = new DerivativeStructure(2, 1, 0, u);
|
||||
for (double v = -1.0; v < 1; v += 0.01) {
|
||||
DerivativeStructure dsV = new DerivativeStructure(2, 1, 1, v);
|
||||
DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
|
||||
Assert.assertEquals(a * u + b * v, mdf.value(new double[] { u, v }), 1.0e-15);
|
||||
Assert.assertEquals(a * u + b * v, y.getValue(), 1.0e-15);
|
||||
Assert.assertEquals(a, y.getPartialDerivative(1, 0), 1.0e-15);
|
||||
Assert.assertEquals(b, y.getPartialDerivative(0, 1), 1.0e-15);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
mdf.value(new DerivativeStructure[] { new DerivativeStructure(1, 3, 0.0), new DerivativeStructure(1, 3, 0.0) });
|
||||
Assert.fail("an exception should have been thrown");
|
||||
} catch (NumberIsTooLargeException e) {
|
||||
Assert.assertEquals(1, e.getMax());
|
||||
Assert.assertEquals(3, e.getArgument());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToDifferentiableMultivariateInconsistentGradient() {
|
||||
|
||||
final double a = 1.5;
|
||||
final double b = 0.5;
|
||||
final MultivariateFunction f = new MultivariateFunction() {
|
||||
@Override
|
||||
public double value(final double[] point) {
|
||||
return a * point[0] + b * point[1];
|
||||
}
|
||||
};
|
||||
final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
|
||||
@Override
|
||||
public double[] value(final double[] point) {
|
||||
return new double[] { a, b, 0.0 };
|
||||
}
|
||||
};
|
||||
final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
|
||||
|
||||
try {
|
||||
DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, 0.0);
|
||||
mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
|
||||
Assert.fail("an exception should have been thrown");
|
||||
} catch (DimensionMismatchException e) {
|
||||
Assert.assertEquals(2, e.getDimension());
|
||||
Assert.assertEquals(3, e.getArgument());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDerivativeUnivariate() {
|
||||
|
||||
final UnivariateDifferentiableFunction f = new UnivariateDifferentiableFunction() {
|
||||
|
||||
@Override
|
||||
public double value(double x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DerivativeStructure value(DerivativeStructure x) {
|
||||
return x.multiply(x);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
final UnivariateFunction f0 = FunctionUtils.derivative(f, 0);
|
||||
final UnivariateFunction f1 = FunctionUtils.derivative(f, 1);
|
||||
final UnivariateFunction f2 = FunctionUtils.derivative(f, 2);
|
||||
|
||||
for (double t = -1.0; t < 1; t += 0.01) {
|
||||
Assert.assertEquals(t * t, f0.value(t), 1.0e-15);
|
||||
Assert.assertEquals(2 * t, f1.value(t), 1.0e-15);
|
||||
Assert.assertEquals(2, f2.value(t), 1.0e-15);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDerivativeMultivariate() {
|
||||
|
||||
final double a = 1.5;
|
||||
final double b = 0.5;
|
||||
final double c = 0.25;
|
||||
final MultivariateDifferentiableFunction mdf = new MultivariateDifferentiableFunction() {
|
||||
|
||||
@Override
|
||||
public double value(double[] point) {
|
||||
return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
|
||||
}
|
||||
|
||||
@Override
|
||||
public DerivativeStructure value(DerivativeStructure[] point) {
|
||||
DerivativeStructure x = point[0];
|
||||
DerivativeStructure y = point[1];
|
||||
DerivativeStructure x2 = x.multiply(x);
|
||||
DerivativeStructure y2 = y.multiply(y);
|
||||
DerivativeStructure xy = x.multiply(y);
|
||||
return x2.multiply(a).add(y2.multiply(b)).add(xy.multiply(c));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
final MultivariateFunction f = FunctionUtils.derivative(mdf, new int[] { 0, 0 });
|
||||
final MultivariateFunction dfdx = FunctionUtils.derivative(mdf, new int[] { 1, 0 });
|
||||
final MultivariateFunction dfdy = FunctionUtils.derivative(mdf, new int[] { 0, 1 });
|
||||
final MultivariateFunction d2fdx2 = FunctionUtils.derivative(mdf, new int[] { 2, 0 });
|
||||
final MultivariateFunction d2fdy2 = FunctionUtils.derivative(mdf, new int[] { 0, 2 });
|
||||
final MultivariateFunction d2fdxdy = FunctionUtils.derivative(mdf, new int[] { 1, 1 });
|
||||
|
||||
for (double x = -1.0; x < 1; x += 0.01) {
|
||||
for (double y = -1.0; y < 1; y += 0.01) {
|
||||
Assert.assertEquals(a * x * x + b * y * y + c * x * y, f.value(new double[] { x, y }), 1.0e-15);
|
||||
Assert.assertEquals(2 * a * x + c * y, dfdx.value(new double[] { x, y }), 1.0e-15);
|
||||
Assert.assertEquals(2 * b * y + c * x, dfdy.value(new double[] { x, y }), 1.0e-15);
|
||||
Assert.assertEquals(2 * a, d2fdx2.value(new double[] { x, y }), 1.0e-15);
|
||||
Assert.assertEquals(2 * b, d2fdy2.value(new double[] { x, y }), 1.0e-15);
|
||||
Assert.assertEquals(c, d2fdxdy.value(new double[] { x, y }), 1.0e-15);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue