Added a wrapper class to compute gradient from differentiable function.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1383885 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2012-09-12 11:22:56 +00:00
parent 6fe3ae0e6c
commit 88678b58a4
2 changed files with 155 additions and 0 deletions

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.differentiation;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
/** Class representing the gradient of a multivariate function.
* <p>
* The vectorial components of the function represent the derivatives
* with respect to each function parameters.
* </p>
* @version $Id$
* @since 3.1
*/
public class GradientFunction implements MultivariateVectorFunction {
/** Underlying real-valued function. */
private final MultivariateDifferentiableFunction f;
/** Simple constructor.
* @param f underlying real-valued function
*/
public GradientFunction(final MultivariateDifferentiableFunction f) {
this.f = f;
}
/** {@inheritDoc} */
public double[] value(double[] point)
throws IllegalArgumentException {
// set up parameters
final DerivativeStructure[] dsX = new DerivativeStructure[point.length];
for (int i = 0; i < point.length; ++i) {
dsX[i] = new DerivativeStructure(point.length, 1, i, point[i]);
}
// compute the derivatives
final DerivativeStructure dsY = f.value(dsX);
// extract the gradient
final double[] y = new double[point.length];
final int[] orders = new int[point.length];
for (int i = 0; i < point.length; ++i) {
orders[i] = 1;
y[i] = dsY.getPartialDerivative(orders);
orders[i] = 0;
}
return y;
}
}

View File

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.differentiation;
import org.apache.commons.math3.TestUtils;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.util.FastMath;
import org.junit.Test;
/**
* Test for class {@link GradientFunction}.
*/
public class GradientFunctionTest {
@Test
public void test2DDistance() {
EuclideanDistance f = new EuclideanDistance();
GradientFunction g = new GradientFunction(f);
for (double x = -10; x < 10; x += 0.5) {
for (double y = -10; y < 10; y += 0.5) {
double[] point = new double[] { x, y };
TestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
}
}
}
@Test
public void test3DDistance() {
EuclideanDistance f = new EuclideanDistance();
GradientFunction g = new GradientFunction(f);
for (double x = -10; x < 10; x += 0.5) {
for (double y = -10; y < 10; y += 0.5) {
for (double z = -10; z < 10; z += 0.5) {
double[] point = new double[] { x, y, z };
TestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
}
}
}
}
private static class EuclideanDistance implements MultivariateDifferentiableFunction {
public double value(double[] point) {
double d2 = 0;
for (double x : point) {
d2 += x * x;
}
return FastMath.sqrt(d2);
}
public DerivativeStructure value(DerivativeStructure[] point)
throws DimensionMismatchException, MathIllegalArgumentException {
DerivativeStructure d2 = point[0].getField().getZero();
for (DerivativeStructure x : point) {
d2 = d2.add(x.multiply(x));
}
return d2.sqrt();
}
public double[] gradient(double[] point) {
double[] gradient = new double[point.length];
double d = value(point);
for (int i = 0; i < point.length; ++i) {
gradient[i] = point[i] / d;
}
return gradient;
}
}
}