Converters for univariate and multivariate differentiable functions.

JIRA: MATH-1143
This commit is contained in:
Luc Maisonobe 2015-05-03 19:18:09 +02:00
parent cb21480cb1
commit 613afdb0c3
3 changed files with 402 additions and 0 deletions

View File

@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</release>
<release version="4.0" date="XXXX-XX-XX" description="">
<action dev="luc" type="fix" issue="MATH-1143">
Added helper methods to FunctionUtils for converting univariate and multivariate differentiable functions.
</action>
<action dev="tn" type="fix" issue="MATH-964">
Removed unused package private class PollardRho in package primes.
</action>

View File

@ -18,12 +18,14 @@
package org.apache.commons.math4.analysis;
import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.function.Identity;
import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.NumberIsTooLargeException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.util.MathArrays;
/**
* Utilities for manipulating function objects.
@ -337,4 +339,206 @@ public class FunctionUtils {
return s;
}
/** Build a {@link UnivariateDifferentiableFunction} from a regular function and its derivatives.
 * <p>
 * This method handles the case with one free parameter and several derivatives.
 * For the case with several free parameters and only first order derivatives,
 * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
 * There is no direct support for intermediate cases, with several free parameters
 * and order 2 or more derivatives, as it would be difficult to specify all the
 * cross derivatives.
 * </p>
 * <p>
 * Note that the derivatives are expected to be computed only with respect to the
 * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
 * Even if the built function is later used in a composition like f(sin(t)), the provided
 * derivatives should <em>not</em> apply the composition with sine and its derivatives by
 * themselves. The composition will be done automatically here and the result will properly
 * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> even though the
 * provided derivative functions know nothing about the sine function.
 * </p>
 * <p>
 * The returned function throws a {@link NumberIsTooLargeException} when it is evaluated
 * on a {@link DerivativeStructure} whose derivation order exceeds the number of
 * derivatives supplied here.
 * </p>
 * @param f base function f(x)
 * @param derivatives derivatives of the base function, in increasing differentiation order
 * @return a differentiable function with value and all specified derivatives
 * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
 * @see #derivative(UnivariateDifferentiableFunction, int)
 */
public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
                                                                final UnivariateFunction ... derivatives) {
    return new UnivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double x) {
            // plain evaluation simply delegates to the base function
            return f.value(x);
        }

        /** {@inheritDoc} */
        @Override
        public DerivativeStructure value(final DerivativeStructure x) {

            // one user-supplied derivative is needed per requested derivation order
            final int requestedOrder = x.getOrder();
            if (requestedOrder > derivatives.length) {
                throw new NumberIsTooLargeException(requestedOrder, derivatives.length, true);
            }

            // slot 0 holds f(x0), slot k holds the k-th derivative evaluated at x0
            final double x0 = x.getValue();
            final double[] valueAndDerivatives = new double[requestedOrder + 1];
            valueAndDerivatives[0] = f.value(x0);
            for (int k = 1; k <= requestedOrder; ++k) {
                valueAndDerivatives[k] = derivatives[k - 1].value(x0);
            }

            // let the derivative structure combine these with its own derivatives
            return x.compose(valueAndDerivatives);

        }

    };
}
/** Convert regular functions to {@link MultivariateDifferentiableFunction}.
 * <p>
 * This method handles the case with several free parameters and only first order derivatives.
 * For the case with one free parameter and several derivatives,
 * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
 * There is no direct support for intermediate cases, with several free parameters
 * and order 2 or more derivatives, as it would be difficult to specify all the
 * cross derivatives.
 * </p>
 * <p>
 * Note that the gradient is expected to be computed only with respect to the
 * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
 * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
 * gradient should <em>not</em> apply the composition with sine or cosine and their derivatives by
 * itself. The composition will be done automatically here and the result will properly
 * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt even though the provided gradient function
 * knows nothing about the sine or cosine functions.
 * </p>
 * <p>
 * When evaluated on {@link DerivativeStructure} arguments, the returned function throws:
 * </p>
 * <ul>
 *   <li>{@link NotStrictlyPositiveException} if the argument array is empty (the number
 *   of free parameters of the result cannot be determined),</li>
 *   <li>{@link NumberIsTooLargeException} if any argument has a derivation order
 *   larger than 1,</li>
 *   <li>{@link DimensionMismatchException} if the gradient function returns an array
 *   whose length differs from the number of arguments.</li>
 * </ul>
 * @param f base function f(x)
 * @param gradient gradient of the base function
 * @return a differentiable function with value and gradient
 * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
 * @see #derivative(MultivariateDifferentiableFunction, int[])
 */
public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
                                                                  final MultivariateVectorFunction gradient) {
    return new MultivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double[] point) {
            return f.value(point);
        }

        /** {@inheritDoc} */
        @Override
        public DerivativeStructure value(final DerivativeStructure[] point) {

            // the number of free parameters is read from the first argument below,
            // so an empty array must be rejected explicitly instead of failing
            // later with an ArrayIndexOutOfBoundsException
            if (point.length == 0) {
                throw new NotStrictlyPositiveException(point.length);
            }

            // set up the input parameters
            final double[] dPoint = new double[point.length];
            for (int i = 0; i < point.length; ++i) {
                dPoint[i] = point[i].getValue();
                if (point[i].getOrder() > 1) {
                    // only a gradient is available, higher orders cannot be provided
                    throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
                }
            }

            // evaluate regular functions
            final double    v = f.value(dPoint);
            final double[] dv = gradient.value(dPoint);
            if (dv.length != point.length) {
                // the gradient function is inconsistent
                throw new DimensionMismatchException(dv.length, point.length);
            }

            // build the combined derivative
            final int parameters = point[0].getFreeParameters();
            final double[] partials = new double[point.length];
            final double[] packed = new double[parameters + 1];
            packed[0] = v;
            final int[] orders = new int[parameters];
            for (int i = 0; i < parameters; ++i) {

                // we differentiate once with respect to parameter i
                orders[i] = 1;
                for (int j = 0; j < point.length; ++j) {
                    partials[j] = point[j].getPartialDerivative(orders);
                }
                // reset the selector so the next iteration picks a single parameter again
                orders[i] = 0;

                // compose partial derivatives (chain rule: sum over j of df/dx_j * dx_j/dp_i)
                packed[i + 1] = MathArrays.linearCombination(dv, partials);

            }

            return new DerivativeStructure(parameters, 1, packed);

        }

    };
}
/** Wrap a {@link UnivariateDifferentiableFunction} as a plain
 * {@link UnivariateFunction} that returns its n<sup>th</sup> order derivative.
 * <p>
 * This converter is only a convenience method. Beware that computing only one derivative
 * does not save any computation, as the original function is really called under the hood
 * and the derivative is then extracted from the full {@link DerivativeStructure} result.
 * </p>
 * @param f original function, with value and all its derivatives
 * @param order order of the derivative to extract
 * @return function computing the derivative at required order
 * @see #derivative(MultivariateDifferentiableFunction, int[])
 * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
 */
public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
    return new UnivariateFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double x) {
            // evaluate f on a structure with one free parameter (index 0) set to x,
            // carrying derivatives up to the requested order
            final DerivativeStructure evaluated = f.value(new DerivativeStructure(1, order, 0, x));
            // keep only the derivative that was asked for
            return evaluated.getPartialDerivative(order);
        }

    };
}
/** Convert a {@link MultivariateDifferentiableFunction} to a
 * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
 * <p>
 * This converter is only a convenience method. Beware that computing only one derivative
 * does not save any computation, as the original function is really called under the hood
 * and the derivative is then extracted from the full {@link DerivativeStructure} result.
 * </p>
 * <p>
 * The {@code orders} array is copied when this method is called, so later changes made
 * by the caller to that array do not affect the returned function.
 * </p>
 * @param f original function, with value and all its derivatives
 * @param orders orders of the derivative to extract, one for each free parameter
 * @return function computing the derivative at required order
 * @see #derivative(UnivariateDifferentiableFunction, int)
 * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
 */
public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {

    // defensive copy: the derivative selector must not change behind the returned function's back
    final int[] selectedOrders = orders.clone();

    // the maximum differentiation order is the sum of all orders;
    // it does not depend on the evaluation point, so it is computed only once
    int sum = 0;
    for (final int order : selectedOrders) {
        sum += order;
    }
    final int sumOrders = sum;

    return new MultivariateFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double[] point) {

            // set up the input parameters: coordinate i is free parameter number i
            final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
            for (int i = 0; i < point.length; ++i) {
                dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
            }

            // evaluate the full structure and extract the requested partial derivative
            return f.value(dsPoint).getPartialDerivative(selectedOrders);

        }

    };
}
}

View File

@ -18,6 +18,7 @@
package org.apache.commons.math4.analysis;
import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.function.Add;
import org.apache.commons.math4.analysis.function.Constant;
@ -35,6 +36,7 @@ import org.apache.commons.math4.analysis.function.Pow;
import org.apache.commons.math4.analysis.function.Power;
import org.apache.commons.math4.analysis.function.Sin;
import org.apache.commons.math4.analysis.function.Sinc;
import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.NumberIsTooLargeException;
import org.apache.commons.math4.util.FastMath;
@ -233,4 +235,197 @@ public class FunctionUtilsTest {
}
}
@Test
public void testToDifferentiableUnivariate() {

    // f(x) = x^2 and its first two derivatives, each supplied as a plain function
    final UnivariateFunction square = new UnivariateFunction() {
        @Override
        public double value(final double x) {
            return x * x;
        }
    };
    final UnivariateFunction firstDerivative = new UnivariateFunction() {
        @Override
        public double value(final double x) {
            return 2 * x;
        }
    };
    final UnivariateFunction secondDerivative = new UnivariateFunction() {
        @Override
        public double value(final double x) {
            return 2;
        }
    };
    final UnivariateDifferentiableFunction f =
            FunctionUtils.toDifferentiable(square, firstDerivative, secondDerivative);

    for (double t = -1.0; t < 1; t += 0.01) {
        final double sinT = FastMath.sin(t);
        final double cosT = FastMath.cos(t);

        // x = sin(t), so the result must contain the derivatives of sin^2(t) with respect to t
        final DerivativeStructure dsT = new DerivativeStructure(1, 2, 0, t);
        final DerivativeStructure y = f.value(dsT.sin());

        Assert.assertEquals(sinT * sinT, f.value(sinT), 1.0e-15);
        Assert.assertEquals(sinT * sinT, y.getValue(), 1.0e-15);
        Assert.assertEquals(2 * cosT * sinT, y.getPartialDerivative(1), 1.0e-15);
        Assert.assertEquals(2 * (1 - 2 * sinT * sinT), y.getPartialDerivative(2), 1.0e-15);
    }

    // only two derivatives were supplied, so asking for order 3 must be rejected
    try {
        final DerivativeStructure orderThree = new DerivativeStructure(1, 3, 0.0);
        f.value(orderThree);
        Assert.fail("an exception should have been thrown");
    } catch (NumberIsTooLargeException e) {
        Assert.assertEquals(2, e.getMax());
        Assert.assertEquals(3, e.getArgument());
    }

}
@Test
public void testToDifferentiableMultivariate() {

    // linear function f(x, y) = a x + b y, whose gradient is the constant vector (a, b)
    final double a = 1.5;
    final double b = 0.5;
    final MultivariateFunction f = new MultivariateFunction() {
        @Override
        public double value(final double[] point) {
            return a * point[0] + b * point[1];
        }
    };
    final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
        @Override
        public double[] value(final double[] point) {
            return new double[] { a, b };
        }
    };
    final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);

    // composition case: x = sin(t), y = cos(t), hence the method really becomes univariate
    for (double t = -1.0; t < 1; t += 0.01) {
        final double sinT = FastMath.sin(t);
        final double cosT = FastMath.cos(t);
        final DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, t);
        final DerivativeStructure[] arguments = new DerivativeStructure[] { dsT.sin(), dsT.cos() };
        final DerivativeStructure y = mdf.value(arguments);
        Assert.assertEquals(a * sinT + b * cosT, y.getValue(), 1.0e-15);
        Assert.assertEquals(a * cosT - b * sinT, y.getPartialDerivative(1), 1.0e-15);
    }

    // independent variables case: u and v are two distinct free parameters
    for (double u = -1.0; u < 1; u += 0.01) {
        final DerivativeStructure dsU = new DerivativeStructure(2, 1, 0, u);
        for (double v = -1.0; v < 1; v += 0.01) {
            final DerivativeStructure dsV = new DerivativeStructure(2, 1, 1, v);
            final DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
            final double expectedValue = a * u + b * v;
            Assert.assertEquals(expectedValue, mdf.value(new double[] { u, v }), 1.0e-15);
            Assert.assertEquals(expectedValue, y.getValue(), 1.0e-15);
            Assert.assertEquals(a, y.getPartialDerivative(1, 0), 1.0e-15);
            Assert.assertEquals(b, y.getPartialDerivative(0, 1), 1.0e-15);
        }
    }

    // only first order derivatives are supported, order 3 arguments must be rejected
    try {
        final DerivativeStructure[] orderThree = new DerivativeStructure[] {
            new DerivativeStructure(1, 3, 0.0),
            new DerivativeStructure(1, 3, 0.0)
        };
        mdf.value(orderThree);
        Assert.fail("an exception should have been thrown");
    } catch (NumberIsTooLargeException e) {
        Assert.assertEquals(1, e.getMax());
        Assert.assertEquals(3, e.getArgument());
    }

}
@Test
public void testToDifferentiableMultivariateInconsistentGradient() {

    // two-variable function paired with a gradient that wrongly reports three components
    final double a = 1.5;
    final double b = 0.5;
    final MultivariateFunction f = new MultivariateFunction() {
        @Override
        public double value(final double[] point) {
            return a * point[0] + b * point[1];
        }
    };
    final MultivariateVectorFunction threeComponentsGradient = new MultivariateVectorFunction() {
        @Override
        public double[] value(final double[] point) {
            return new double[] { a, b, 0.0 };
        }
    };
    final MultivariateDifferentiableFunction mdf =
            FunctionUtils.toDifferentiable(f, threeComponentsGradient);

    try {
        final DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, 0.0);
        final DerivativeStructure[] arguments = new DerivativeStructure[] { dsT.sin(), dsT.cos() };
        mdf.value(arguments);
        Assert.fail("an exception should have been thrown");
    } catch (DimensionMismatchException e) {
        // the mismatch is reported as gradient length 3 versus 2 arguments
        Assert.assertEquals(2, e.getDimension());
        Assert.assertEquals(3, e.getArgument());
    }

}
@Test
public void testDerivativeUnivariate() {

    // f(x) = x^2, implemented for both plain doubles and derivative structures
    final UnivariateDifferentiableFunction square = new UnivariateDifferentiableFunction() {
        @Override
        public double value(double x) {
            return x * x;
        }
        @Override
        public DerivativeStructure value(DerivativeStructure x) {
            return x.multiply(x);
        }
    };

    // extract the function itself (order 0) and its first two derivatives
    final UnivariateFunction order0 = FunctionUtils.derivative(square, 0);
    final UnivariateFunction order1 = FunctionUtils.derivative(square, 1);
    final UnivariateFunction order2 = FunctionUtils.derivative(square, 2);

    for (double t = -1.0; t < 1; t += 0.01) {
        Assert.assertEquals(t * t, order0.value(t), 1.0e-15);
        Assert.assertEquals(2 * t, order1.value(t), 1.0e-15);
        Assert.assertEquals(2,     order2.value(t), 1.0e-15);
    }

}
@Test
public void testDerivativeMultivariate() {

    // quadratic form f(x, y) = a x^2 + b y^2 + c x y
    final double a = 1.5;
    final double b = 0.5;
    final double c = 0.25;
    final MultivariateDifferentiableFunction mdf = new MultivariateDifferentiableFunction() {
        @Override
        public double value(double[] point) {
            return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
        }
        @Override
        public DerivativeStructure value(DerivativeStructure[] point) {
            final DerivativeStructure x = point[0];
            final DerivativeStructure y = point[1];
            final DerivativeStructure ax2 = x.multiply(x).multiply(a);
            final DerivativeStructure by2 = y.multiply(y).multiply(b);
            final DerivativeStructure cxy = x.multiply(y).multiply(c);
            return ax2.add(by2).add(cxy);
        }
    };

    // the function itself, then every partial derivative up to total order 2
    final MultivariateFunction f       = FunctionUtils.derivative(mdf, new int[] { 0, 0 });
    final MultivariateFunction dfdx    = FunctionUtils.derivative(mdf, new int[] { 1, 0 });
    final MultivariateFunction dfdy    = FunctionUtils.derivative(mdf, new int[] { 0, 1 });
    final MultivariateFunction d2fdx2  = FunctionUtils.derivative(mdf, new int[] { 2, 0 });
    final MultivariateFunction d2fdy2  = FunctionUtils.derivative(mdf, new int[] { 0, 2 });
    final MultivariateFunction d2fdxdy = FunctionUtils.derivative(mdf, new int[] { 1, 1 });

    for (double x = -1.0; x < 1; x += 0.01) {
        for (double y = -1.0; y < 1; y += 0.01) {
            final double[] xy = new double[] { x, y };
            Assert.assertEquals(a * x * x + b * y * y + c * x * y, f.value(xy),       1.0e-15);
            Assert.assertEquals(2 * a * x + c * y,                 dfdx.value(xy),    1.0e-15);
            Assert.assertEquals(2 * b * y + c * x,                 dfdy.value(xy),    1.0e-15);
            Assert.assertEquals(2 * a,                             d2fdx2.value(xy),  1.0e-15);
            Assert.assertEquals(2 * b,                             d2fdy2.value(xy),  1.0e-15);
            Assert.assertEquals(c,                                 d2fdxdy.value(xy), 1.0e-15);
        }
    }

}
}