Added converters for multivariate functions.

The converters make it possible to convert back and forth between the older
and the newer differentiation API. They are considered temporary methods
for version 3.1 and will be removed in 4.0, as only the new API will
remain.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1401837 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2012-10-24 19:39:45 +00:00
parent 3bae991455
commit 7b5a64c0bb
3 changed files with 411 additions and 14 deletions

View File

@ -18,6 +18,8 @@
package org.apache.commons.math3.analysis; package org.apache.commons.math3.analysis;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math3.analysis.function.Identity; import org.apache.commons.math3.analysis.function.Identity;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
@ -466,9 +468,8 @@ public class FunctionUtils {
/** Convert a {@link DifferentiableUnivariateFunction} into a {@link UnivariateDifferentiableFunction}. /** Convert a {@link DifferentiableUnivariateFunction} into a {@link UnivariateDifferentiableFunction}.
* <p> * <p>
* Note that the converted function is able to handle {@link DerivativeStructure} with * Note that the converted function is able to handle {@link DerivativeStructure} up to order one.
* <em>only</em> one parameter and up to order one. If the function is called with * If the function is called with higher order, a {@link NumberIsTooLargeException} will be thrown.
* more parameters or higher order, a {@link DimensionMismatchException} will be thrown.
* </p> * </p>
* @param f function to convert * @param f function to convert
* @return converted function * @return converted function
@ -484,22 +485,304 @@ public class FunctionUtils {
return f.value(x); return f.value(x);
} }
/** {@inheritDoc}
 * <p>
 * Only derivation orders 0 and 1 are supported. For order 1, the first order
 * partial derivatives of the argument are scaled by the derivative of the
 * wrapped function (chain rule).
 * </p>
 * @exception NumberIsTooLargeException if derivation order is greater than 1
 */
public DerivativeStructure value(final DerivativeStructure t)
    throws NumberIsTooLargeException {

    // reject unsupported orders before evaluating anything
    final int order = t.getOrder();
    if (order != 0 && order != 1) {
        throw new NumberIsTooLargeException(order, 1, true);
    }

    final int freeParameters = t.getFreeParameters();
    final double x = t.getValue();
    final double y = f.value(x);
    if (order == 0) {
        // value only, no derivatives to propagate
        return new DerivativeStructure(freeParameters, 0, y);
    }

    // order 1: slot 0 holds the value, slot p + 1 the derivative
    // with respect to free parameter p
    final double slope = f.derivative().value(x);
    final double[] data = new double[freeParameters + 1];
    data[0] = y;
    final int[] selector = new int[freeParameters];
    for (int p = 0; p < freeParameters; ++p) {
        selector[p] = 1;
        data[p + 1] = slope * t.getPartialDerivative(selector);
        selector[p] = 0;
    }
    return new DerivativeStructure(freeParameters, 1, data);

}
};
}
/** Convert a {@link MultivariateDifferentiableFunction} into a {@link DifferentiableMultivariateFunction}.
 * <p>
 * Plain values are forwarded to the wrapped function. Derivatives are obtained by
 * evaluating the wrapped function on first order {@link DerivativeStructure}
 * instances built from the evaluation point, and reading the partial derivatives
 * back from the result.
 * </p>
 * @param f function to convert
 * @return converted function
 * @deprecated this conversion method is temporary in version 3.1, as the {@link
 * DifferentiableMultivariateFunction} interface itself is deprecated
 */
@Deprecated
public static DifferentiableMultivariateFunction toDifferentiableMultivariateFunction(final MultivariateDifferentiableFunction f) {
    return new DifferentiableMultivariateFunction() {

        /** {@inheritDoc} */
        public double value(final double[] x) {
            return f.value(x);
        }

        /** {@inheritDoc} */
        public MultivariateFunction partialDerivative(final int k) {
            return new MultivariateFunction() {
                /** {@inheritDoc} */
                public double value(final double[] x) {

                    // one free parameter is enough here: coordinate k is built with
                    // parameter index 0, every other coordinate is built without an
                    // index (see the DerivativeStructure constructors)
                    final DerivativeStructure[] arguments = new DerivativeStructure[x.length];
                    for (int i = 0; i < arguments.length; ++i) {
                        arguments[i] = (i == k) ?
                                       new DerivativeStructure(1, 1, 0, x[i]) :
                                       new DerivativeStructure(1, 1, x[i]);
                    }

                    // first order derivative with respect to the single free parameter
                    return f.value(arguments).getPartialDerivative(1);

                }
            };
        }

        /** {@inheritDoc} */
        public MultivariateVectorFunction gradient() {
            return new MultivariateVectorFunction() {
                /** {@inheritDoc} */
                public double[] value(final double[] x) {

                    // one free parameter per coordinate of the evaluation point
                    final int dimension = x.length;
                    final DerivativeStructure[] arguments = new DerivativeStructure[dimension];
                    for (int i = 0; i < dimension; ++i) {
                        arguments[i] = new DerivativeStructure(dimension, 1, i, x[i]);
                    }
                    final DerivativeStructure result = f.value(arguments);

                    // read the first order partial derivatives one parameter at a time
                    final double[] partials = new double[dimension];
                    final int[] selector = new int[dimension];
                    for (int i = 0; i < dimension; ++i) {
                        selector[i] = 1;
                        partials[i] = result.getPartialDerivative(selector);
                        selector[i] = 0;
                    }
                    return partials;

                }
            };
        }

    };
}
/** Convert a {@link DifferentiableMultivariateFunction} into a {@link MultivariateDifferentiableFunction}.
* <p>
* Note that the converted function is able to handle {@link DerivativeStructure} elements
* that all have the same number of free parameters and order, and with order at most 1.
* If the function is called with inconsistent numbers of free parameters or higher order, a
* {@link DimensionMismatchException} or a {@link NumberIsTooLargeException} will be thrown.
* </p>
* @param f function to convert
* @return converted function
* @deprecated this conversion method is temporary in version 3.1, as the {@link
* DifferentiableMultivariateFunction} interface itself is deprecated
*/
@Deprecated
public static MultivariateDifferentiableFunction toMultivariateDifferentiableFunction(final DifferentiableMultivariateFunction f) {
return new MultivariateDifferentiableFunction() {
/** {@inheritDoc} */
public double value(final double[] x) {
return f.value(x);
}
/** {@inheritDoc} /** {@inheritDoc}
* @exception DimensionMismatchException if number of parameters or derivation * @exception DimensionMismatchException if number of parameters or derivation
* order are higher than 1 * order are higher than 1
*/ */
public DerivativeStructure value(final DerivativeStructure t) public DerivativeStructure value(final DerivativeStructure[] t)
throws DimensionMismatchException { throws DimensionMismatchException, NumberIsTooLargeException {
if (t.getFreeParameters() != 1) {
throw new DimensionMismatchException(t.getFreeParameters(), 1); // check parameters and orders limits
final int parameters = t[0].getFreeParameters();
final int order = t[0].getOrder();
final int n = t.length;
if (order > 1) {
throw new NumberIsTooLargeException(order, 1, true);
} }
if (t.getOrder() > 1) {
throw new DimensionMismatchException(t.getOrder(), 1); // check all elements in the array are consistent
for (int i = 0; i < n; ++i) {
if (t[i].getFreeParameters() != parameters) {
throw new DimensionMismatchException(t[i].getFreeParameters(), parameters);
}
if (t[i].getOrder() != order) {
throw new DimensionMismatchException(t[i].getOrder(), order);
}
} }
return t.compose(new double[] {
f.value(t.getValue()), // delegate computation to underlying function
f.derivative().value(t.getValue()) final double[] point = new double[n];
}); for (int i = 0; i < n; ++i) {
point[i] = t[i].getValue();
}
final double value = f.value(point);
final double[] gradient = f.gradient().value(point);
// merge value and gradient into one DerivativeStructure
final double[] derivatives = new double[parameters + 1];
derivatives[0] = value;
final int[] orders = new int[parameters];
for (int i = 0; i < parameters; ++i) {
orders[i] = 1;
for (int j = 0; j < n; ++j) {
derivatives[i + 1] += gradient[j] * t[j].getPartialDerivative(orders);
}
orders[i] = 0;
}
return new DerivativeStructure(parameters, order, derivatives);
}
};
}
/** Convert a {@link MultivariateDifferentiableVectorFunction} into a {@link DifferentiableMultivariateVectorFunction}.
 * <p>
 * Plain values are forwarded to the wrapped function. The Jacobian is obtained by
 * evaluating the wrapped function on first order {@link DerivativeStructure}
 * instances, one free parameter per input coordinate, and reading the partial
 * derivatives of each output component.
 * </p>
 * @param f function to convert
 * @return converted function
 * @deprecated this conversion method is temporary in version 3.1, as the {@link
 * DifferentiableMultivariateVectorFunction} interface itself is deprecated
 */
@Deprecated
public static DifferentiableMultivariateVectorFunction toDifferentiableMultivariateVectorFunction(final MultivariateDifferentiableVectorFunction f) {
    return new DifferentiableMultivariateVectorFunction() {

        /** {@inheritDoc} */
        public double[] value(final double[] x) {
            return f.value(x);
        }

        /** {@inheritDoc} */
        public MultivariateMatrixFunction jacobian() {
            return new MultivariateMatrixFunction() {
                /** {@inheritDoc} */
                public double[][] value(final double[] x) {

                    // one free parameter per coordinate of the evaluation point
                    final int columns = x.length;
                    final DerivativeStructure[] arguments = new DerivativeStructure[columns];
                    for (int c = 0; c < columns; ++c) {
                        arguments[c] = new DerivativeStructure(columns, 1, c, x[c]);
                    }
                    final DerivativeStructure[] outputs = f.value(arguments);

                    // row r holds the first order partial derivatives of output component r
                    final int rows = outputs.length;
                    final double[][] matrix = new double[rows][columns];
                    final int[] selector = new int[columns];
                    for (int r = 0; r < rows; ++r) {
                        final DerivativeStructure component = outputs[r];
                        for (int c = 0; c < columns; ++c) {
                            selector[c] = 1;
                            matrix[r][c] = component.getPartialDerivative(selector);
                            selector[c] = 0;
                        }
                    }
                    return matrix;

                }
            };
        }

    };
}
/** Convert a {@link DifferentiableMultivariateVectorFunction} into a {@link MultivariateDifferentiableVectorFunction}.
* <p>
* Note that the converted function is able to handle {@link DerivativeStructure} elements
* that all have the same number of free parameters and order, and with order at most 1.
* If the function is called with inconsistent numbers of free parameters or higher order, a
* {@link DimensionMismatchException} or a {@link NumberIsTooLargeException} will be thrown.
* </p>
* @param f function to convert
* @return converted function
* @deprecated this conversion method is temporary in version 3.1, as the {@link
* DifferentiableMultivariateVectorFunction} interface itself is deprecated
*/
@Deprecated
public static MultivariateDifferentiableVectorFunction toMultivariateDifferentiableVectorFunction(final DifferentiableMultivariateVectorFunction f) {
return new MultivariateDifferentiableVectorFunction() {
/** {@inheritDoc} */
public double[] value(final double[] x) {
// plain values need no conversion: forward the point to the wrapped function
return f.value(x);
}
/** {@inheritDoc}
* @exception DimensionMismatchException if number of parameters or derivation
* order are higher than 1
*/
public DerivativeStructure[] value(final DerivativeStructure[] t)
throws DimensionMismatchException, NumberIsTooLargeException {
// check parameters and orders limits
final int parameters = t[0].getFreeParameters();
final int order = t[0].getOrder();
final int n = t.length;
if (order > 1) {
throw new NumberIsTooLargeException(order, 1, true);
}
// check all elements in the array are consistent
for (int i = 0; i < n; ++i) {
if (t[i].getFreeParameters() != parameters) {
throw new DimensionMismatchException(t[i].getFreeParameters(), parameters);
}
if (t[i].getOrder() != order) {
throw new DimensionMismatchException(t[i].getOrder(), order);
}
}
// delegate computation to underlying function
final double[] point = new double[n];
for (int i = 0; i < n; ++i) {
point[i] = t[i].getValue();
}
final double[] value = f.value(point);
final double[][] jacobian = f.jacobian().value(point);
// merge value and Jacobian into a DerivativeStructure array
final DerivativeStructure[] merged = new DerivativeStructure[value.length];
for (int k = 0; k < merged.length; ++k) {
final double[] derivatives = new double[parameters + 1];
derivatives[0] = value[k];
final int[] orders = new int[parameters];
for (int i = 0; i < parameters; ++i) {
orders[i] = 1;
for (int j = 0; j < n; ++j) {
derivatives[i + 1] += jacobian[k][j] * t[j].getPartialDerivative(orders);
}
orders[i] = 0;
}
merged[k] = new DerivativeStructure(parameters, order, derivatives);
}
return merged;
} }
}; };

View File

@ -551,7 +551,7 @@ public class DerivativeStructure implements FieldElement<DerivativeStructure>, S
* @exception DimensionMismatchException if the number of derivatives * @exception DimensionMismatchException if the number of derivatives
* in the array is not equal to {@link #getOrder() order} + 1 * in the array is not equal to {@link #getOrder() order} + 1
*/ */
public DerivativeStructure compose(final double[] f) { public DerivativeStructure compose(final double ... f) {
if (f.length != getOrder() + 1) { if (f.length != getOrder() + 1) {
throw new DimensionMismatchException(f.length, getOrder() + 1); throw new DimensionMismatchException(f.length, getOrder() + 1);
} }

View File

@ -18,6 +18,7 @@
package org.apache.commons.math3.analysis; package org.apache.commons.math3.analysis;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math3.analysis.function.Add; import org.apache.commons.math3.analysis.function.Add;
import org.apache.commons.math3.analysis.function.Constant; import org.apache.commons.math3.analysis.function.Constant;
@ -232,4 +233,117 @@ public class FunctionUtilsTest {
Assert.assertEquals("x = " + x, FastMath.sin(x), actual[i], 0.0); Assert.assertEquals("x = " + x, FastMath.sin(x), actual[i], 0.0);
} }
} }
/** Check that the old-API view of a function agrees with the function itself on values and first derivatives. */
@Test
@Deprecated
public void testToDifferentiableUnivariateFunction() {

    // Sin implements both UnivariateDifferentiableFunction and DifferentiableUnivariateFunction
    final Sin reference = new Sin();
    final DifferentiableUnivariateFunction adapted =
            FunctionUtils.toDifferentiableUnivariateFunction(reference);

    for (double x = 0.1; x < 0.5; x += 0.01) {
        final double expectedValue = reference.value(x);
        Assert.assertEquals(expectedValue, adapted.value(x), 1.0e-10);

        final double expectedSlope = reference.derivative().value(x);
        Assert.assertEquals(expectedSlope, adapted.derivative().value(x), 1.0e-10);
    }

}
/** Check that the new-API view of a function propagates value and both first order partial derivatives. */
@Test
@Deprecated
public void testToUnivariateDifferential() {

    // Sin implements both UnivariateDifferentiableFunction and DifferentiableUnivariateFunction
    final Sin reference = new Sin();
    final UnivariateDifferentiableFunction adapted = FunctionUtils.toUnivariateDifferential(reference);

    for (double x = 0.1; x < 0.5; x += 0.01) {
        // argument with two free parameters at order 1
        final DerivativeStructure argument = new DerivativeStructure(2, 1, x, 1.0, 2.0);
        final DerivativeStructure expected = reference.value(argument);
        final DerivativeStructure actual   = adapted.value(argument);

        Assert.assertEquals(expected.getValue(), actual.getValue(), 1.0e-10);
        Assert.assertEquals(expected.getPartialDerivative(1, 0),
                            actual.getPartialDerivative(1, 0),
                            1.0e-10);
        Assert.assertEquals(expected.getPartialDerivative(0, 1),
                            actual.getPartialDerivative(0, 1),
                            1.0e-10);
    }

}
/** Check value and gradient of the old-API view of a hypot function against the closed form gradient. */
@Test
@Deprecated
public void testToDifferentiableMultivariateFunction() {

    final MultivariateDifferentiableFunction hypot = new MultivariateDifferentiableFunction() {

        public double value(double[] point) {
            return FastMath.hypot(point[0], point[1]);
        }

        public DerivativeStructure value(DerivativeStructure[] point) {
            return DerivativeStructure.hypot(point[0], point[1]);
        }

    };

    final DifferentiableMultivariateFunction adapted =
            FunctionUtils.toDifferentiableMultivariateFunction(hypot);

    for (double x = 0.1; x < 0.5; x += 0.01) {
        for (double y = 0.1; y < 0.5; y += 0.01) {
            final double[] point = new double[] { x, y };
            final double h = hypot.value(point);
            Assert.assertEquals(h, adapted.value(point), 1.0e-10);

            // gradient of hypot(x, y) is (x / h, y / h)
            final double[] gradient = adapted.gradient().value(point);
            Assert.assertEquals(x / h, gradient[0], 1.0e-10);
            Assert.assertEquals(y / h, gradient[1], 1.0e-10);
        }
    }

}
/** Check the new-API view of an old-API hypot function against {@code DerivativeStructure.hypot}. */
@Test
@Deprecated
public void testToMultivariateDifferentiableFunction() {

    final DifferentiableMultivariateFunction hypot = new DifferentiableMultivariateFunction() {

        public double value(double[] point) {
            return FastMath.hypot(point[0], point[1]);
        }

        public MultivariateFunction partialDerivative(final int k) {
            return new MultivariateFunction() {
                public double value(double[] point) {
                    return point[k] / FastMath.hypot(point[0], point[1]);
                }
            };
        }

        public MultivariateVectorFunction gradient() {
            return new MultivariateVectorFunction() {
                public double[] value(double[] point) {
                    final double h = FastMath.hypot(point[0], point[1]);
                    return new double[] { point[0] / h, point[1] / h };
                }
            };
        }

    };

    final MultivariateDifferentiableFunction adapted =
            FunctionUtils.toMultivariateDifferentiableFunction(hypot);

    for (double x = 0.1; x < 0.5; x += 0.01) {
        for (double y = 0.1; y < 0.5; y += 0.01) {
            // two arguments sharing three free parameters at order 1
            final DerivativeStructure[] arguments = new DerivativeStructure[] {
                new DerivativeStructure(3, 1, x, 1.0, 2.0, 3.0),
                new DerivativeStructure(3, 1, y, 4.0, 5.0, 6.0)
            };
            final DerivativeStructure expected = DerivativeStructure.hypot(arguments[0], arguments[1]);
            final DerivativeStructure actual   = adapted.value(arguments);

            Assert.assertEquals(expected.getValue(), actual.getValue(), 1.0e-10);
            Assert.assertEquals(expected.getPartialDerivative(1, 0, 0),
                                actual.getPartialDerivative(1, 0, 0),
                                1.0e-10);
            Assert.assertEquals(expected.getPartialDerivative(0, 1, 0),
                                actual.getPartialDerivative(0, 1, 0),
                                1.0e-10);
            Assert.assertEquals(expected.getPartialDerivative(0, 0, 1),
                                actual.getPartialDerivative(0, 0, 1),
                                1.0e-10);
        }
    }

}
} }