Gauss-Hermite quadrature scheme.


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1500018 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2013-07-05 14:20:19 +00:00
parent 3739de305d
commit 7fb87df294
8 changed files with 634 additions and 3 deletions

View File

@ -51,6 +51,14 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties>
<body>
<release version="x.y" date="TBD" description="TBD">
<action dev="erans" type="add" issue="MATH-997">
Implemented Gauss-Hermite quadrature scheme (in package
"o.a.c.m.analysis.integration.gauss").
</action>
<action dev="erans" type="update" issue="MATH-995">
Documented limitation of "IterativeLegendreGaussIntegrator" (added
warning about potential wrong usage).
</action>
<action dev="erans" type="fix" issue="MATH-993">
In "GaussNewtonOptimizer", check for convergence before updating the
parameters estimation for the next iteration.

View File

@ -107,4 +107,24 @@ public class GaussIntegrator {
public int getNumberOfPoints() {
return points.length;
}
/**
* Gets the integration point at the given index.
* The index must be in the valid range but no check is performed.
*
* @return the integration point.
*/
public double getPoint(int index) {
return points[index];
}
/**
* Gets the weight of the integration point at the given index.
* The index must be in the valid range but no check is performed.
*
* @return the weight.
*/
public double getWeight(int index) {
return weights[index];
}
}

View File

@ -20,7 +20,9 @@ import java.math.BigDecimal;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.analysis.BivariateFunction;
import org.apache.commons.math3.util.Pair;
import org.apache.commons.math3.util.FastMath;
/**
* Class that provides different ways to compute the nodes and weights to be
@ -34,9 +36,12 @@ public class GaussIntegratorFactory {
private final BaseRuleFactory<Double> legendre = new LegendreRuleFactory();
/** Generator of Gauss-Legendre integrators. */
private final BaseRuleFactory<BigDecimal> legendreHighPrecision = new LegendreHighPrecisionRuleFactory();
/** Generator of Gauss-Hermite integrators. */
private final BaseRuleFactory<Double> hermite = new HermiteRuleFactory();
/**
* Creates an integrator of the given order, and whose call to the
* Creates a Gauss-Legendre integrator of the given order.
* The call to the
* {@link GaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction)
* integrate} method will perform an integration on the natural interval
* {@code [-1 , 1]}.
@ -49,7 +54,8 @@ public class GaussIntegratorFactory {
}
/**
* Creates an integrator of the given order, and whose call to the
* Creates a Gauss-Legendre integrator of the given order.
* The call to the
* {@link GaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction)
* integrate} method will perform an integration on the given interval.
*
@ -68,7 +74,8 @@ public class GaussIntegratorFactory {
}
/**
* Creates an integrator of the given order, and whose call to the
* Creates a Gauss-Legendre integrator of the given order.
* The call to the
* {@link GaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction)
* integrate} method will perform an integration on the natural interval
* {@code [-1 , 1]}.
@ -101,6 +108,26 @@ public class GaussIntegratorFactory {
lowerBound, upperBound));
}
/**
* Creates a Gauss-Hermite integrator of the given order.
* The call to the
* {@link SymmetricGaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction)
* integrate} method will perform a weighted integration on the interval
* {@code [-&inf;, +&inf;]}: the computed value is the improper integral of
* <code>
* e<sup>-x<sup>2</sup></sup> f(x)
* </code>
* where {@code f(x)} is the function passed to the
* {@link SymmetricGaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction)
* integrate} method.
*
* @param numberOfPoints Order of the integration rule.
* @return a Gauss-Hermite integrator.
*/
public SymmetricGaussIntegrator hermite(int numberOfPoints) {
return new SymmetricGaussIntegrator(getRule(hermite, numberOfPoints));
}
/**
* @param factory Integration rule factory.
* @param numberOfPoints Order of the integration rule.

View File

@ -0,0 +1,179 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.integration.gauss;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.Pair;
import org.apache.commons.math3.util.FastMath;
/**
* Factory that creates a
* <a href="http://en.wikipedia.org/wiki/Gauss-Hermite_quadrature">
* Gauss-type quadrature rule using Hermite polynomials</a>
* of the first kind.
* Such a quadrature rule allows the calculation of improper integrals
* of a function
* <code>
* f(x) e<sup>-x<sup>2</sup></sup>
* </code>
* <br/>
* Recurrence relation and weights computation follow
* <a href="http://en.wikipedia.org/wiki/Abramowitz_and_Stegun"
* Abramowitz and Stegun, 1964</a>.
* <br/>
* The coefficients of the standard Hermite polynomials grow very rapidly;
* in order to avoid overflows, each Hermite polynomial is normalized with
* respect to the underlying scalar product.
* The initial interval for the application of the bisection method is
* based on the roots of the previous Hermite polynomial (interlacing).
* Upper and lower bounds of these roots are provided by
* <quote>
* I. Krasikov,
* <em>Nonnegative quadratic forms and bounds on orthogonal polynomials</em>,
* Journal of Approximation theory <b>111</b>, 31-49
* </quote>
*
* @since 3.3
* @version $Id$
*/
public class HermiteRuleFactory extends BaseRuleFactory<Double> {
/** &pi;<sup>1/2</sup> */
private static final double SQRT_PI = 1.77245385090551602729;
/** &pi;<sup>-1/4</sup> */
private static final double H0 = 7.5112554446494248286e-1;
/** &pi;<sup>-1/4</sup> &radic;2 */
private static final double H1 = 1.0622519320271969145;
/** {@inheritDoc} */
@Override
protected Pair<Double[], Double[]> computeRule(int numberOfPoints)
throws DimensionMismatchException {
if (numberOfPoints == 1) {
// Break recursion.
return new Pair<Double[], Double[]>(new Double[] { 0d },
new Double[] { SQRT_PI });
}
// Get previous rule.
// If it has not been computed yet it will trigger a recursive call
// to this method.
final int lastNumPoints = numberOfPoints - 1;
final Double[] previousPoints = getRuleInternal(lastNumPoints).getFirst();
// Compute next rule.
final Double[] points = new Double[numberOfPoints];
final Double[] weights = new Double[numberOfPoints];
final double sqrtTwoTimesLastNumPoints = FastMath.sqrt(2 * lastNumPoints);
final double sqrtTwoTimesNumPoints = FastMath.sqrt(2 * numberOfPoints);
// Find i-th root of H[n+1] by bracketing.
final int iMax = numberOfPoints / 2;
for (int i = 0; i < iMax; i++) {
// Lower-bound of the interval.
double a = (i == 0) ? -sqrtTwoTimesLastNumPoints : previousPoints[i - 1].doubleValue();
// Upper-bound of the interval.
double b = (iMax == 1) ? -0.5 : previousPoints[i].doubleValue();
// H[j-1](a)
double hma = H0;
// H[j](a)
double ha = H1 * a;
// H[j-1](b)
double hmb = H0;
// H[j](b)
double hb = H1 * b;
for (int j = 1; j < numberOfPoints; j++) {
// Compute H[j+1](a) and H[j+1](b)
final double jp1 = j + 1;
final double s = FastMath.sqrt(2 / jp1);
final double sm = FastMath.sqrt(j / jp1);
final double hpa = s * a * ha - sm * hma;
final double hpb = s * b * hb - sm * hmb;
hma = ha;
ha = hpa;
hmb = hb;
hb = hpb;
}
// Now ha = H[n+1](a), and hma = H[n](a) (same holds for b).
// Middle of the interval.
double c = 0.5 * (a + b);
// P[j-1](c)
double hmc = H0;
// P[j](c)
double hc = H1 * c;
boolean done = false;
while (!done) {
done = b - a <= Math.ulp(c);
hmc = H0;
hc = H1 * c;
for (int j = 1; j < numberOfPoints; j++) {
// Compute H[j+1](c)
final double jp1 = j + 1;
final double s = FastMath.sqrt(2 / jp1);
final double sm = FastMath.sqrt(j / jp1);
final double hpc = s * c * hc - sm * hmc;
hmc = hc;
hc = hpc;
}
// Now h = H[n+1](c) and hm = H[n](c).
if (!done) {
if (ha * hc < 0) {
b = c;
hmb = hmc;
hb = hc;
} else {
a = c;
hma = hmc;
ha = hc;
}
c = 0.5 * (a + b);
}
}
final double d = sqrtTwoTimesNumPoints * hmc;
final double w = 2 / (d * d);
points[i] = c;
weights[i] = w;
final int idx = lastNumPoints - i;
points[idx] = -c;
weights[idx] = w;
}
// If "numberOfPoints" is odd, 0 is a root.
// Note: as written, the test for oddness will work for negative
// integers too (although it is not necessary here), preventing
// a FindBugs warning.
if (numberOfPoints % 2 != 0) {
double hm = H0;
for (int j = 1; j < numberOfPoints; j += 2) {
final double jp1 = j + 1;
hm = -FastMath.sqrt(j / jp1) * hm;
}
final double d = sqrtTwoTimesNumPoints * hm;
final double w = 2 / (d * d);
points[iMax] = 0d;
weights[iMax] = w;
}
return new Pair<Double[], Double[]>(points, weights);
}
}

View File

@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.integration.gauss;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NonMonotonicSequenceException;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.util.Pair;
/**
* This class's implements {@link #integrate(UnivariateFunction) integrate}
* method assuming that the integral is symmetric about 0.
* This allows to reduce numerical errors.
*
* @since 3.3
* @version $Id$
*/
public class SymmetricGaussIntegrator extends GaussIntegrator {
/**
* Creates an integrator from the given {@code points} and {@code weights}.
* The integration interval is defined by the first and last value of
* {@code points} which must be sorted in increasing order.
*
* @param points Integration points.
* @param weights Weights of the corresponding integration nodes.
* @throws NonMonotonicSequenceException if the {@code points} are not
* sorted in increasing order.
* @throws DimensionMismatchException if points and weights don't have the same length
*/
public SymmetricGaussIntegrator(double[] points,
double[] weights)
throws NonMonotonicSequenceException, DimensionMismatchException {
super(points, weights);
}
/**
* Creates an integrator from the given pair of points (first element of
* the pair) and weights (second element of the pair.
*
* @param pointsAndWeights Integration points and corresponding weights.
* @throws NonMonotonicSequenceException if the {@code points} are not
* sorted in increasing order.
*
* @see #SymmetricGaussIntegrator(double[], double[])
*/
public SymmetricGaussIntegrator(Pair<double[], double[]> pointsAndWeights)
throws NonMonotonicSequenceException {
this(pointsAndWeights.getFirst(), pointsAndWeights.getSecond());
}
/**
* {@inheritDoc}
*/
@Override
public double integrate(UnivariateFunction f) {
final int ruleLength = getNumberOfPoints();
if (ruleLength == 1) {
return getWeight(0) * f.value(0d);
}
final int iMax = ruleLength / 2;
double s = 0;
double c = 0;
for (int i = 0; i < iMax; i++) {
final double p = getPoint(i);
final double w = getWeight(i);
final double f1 = f.value(p);
final double f2 = f.value(-p);
final double y = w * (f1 + f2) - c;
final double t = s + y;
c = (t - s) - y;
s = t;
}
if (ruleLength % 2 == 1) {
final double w = getWeight(iMax);
final double y = w * f.value(0d) - c;
final double t = s + y;
c = (t - s) - y;
s = t;
}
return s;
}
}

View File

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.integration.gauss;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.function.Constant;
import org.apache.commons.math3.util.Pair;
import org.junit.Test;
import org.junit.Assert;
/**
* Test for {@link GaussIntegrator} class.
*
* @version $Id$
*/
public class GaussIntegratorTest {
@Test
public void testGetWeights() {
final double[] points = { 0, 1.2, 3.4 };
final double[] weights = { 9.8, 7.6, 5.4 };
final GaussIntegrator integrator
= new GaussIntegrator(new Pair<double[], double[]>(points, weights));
Assert.assertEquals(weights.length, integrator.getNumberOfPoints());
for (int i = 0; i < integrator.getNumberOfPoints(); i++) {
Assert.assertEquals(weights[i], integrator.getWeight(i), 0d);
}
}
@Test
public void testGetPoints() {
final double[] points = { 0, 1.2, 3.4 };
final double[] weights = { 9.8, 7.6, 5.4 };
final GaussIntegrator integrator
= new GaussIntegrator(new Pair<double[], double[]>(points, weights));
Assert.assertEquals(points.length, integrator.getNumberOfPoints());
for (int i = 0; i < integrator.getNumberOfPoints(); i++) {
Assert.assertEquals(points[i], integrator.getPoint(i), 0d);
}
}
@Test
public void testIntegrate() {
final double[] points = { 0, 1, 2, 3, 4, 5 };
final double[] weights = { 1, 1, 1, 1, 1, 1 };
final GaussIntegrator integrator
= new GaussIntegrator(new Pair<double[], double[]>(points, weights));
final double val = 123.456;
final UnivariateFunction c = new Constant(val);
final double s = integrator.integrate(c);
Assert.assertEquals(points.length * val, s, 0d);
}
}

View File

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.integration.gauss;
import java.util.ArrayList;
import java.util.Collection;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.apache.commons.math3.util.FastMath;
/**
* Test of the {@link HermiteRuleFactory}.
* This parameterized test extends the standard test for Gaussian quadrature
* rule, where each monomial is tested in turn.
* Parametrization allows to test automatically 0, 1, ... , {@link #MAX_NUM_POINTS}
* quadrature rules.
*
* @version $Id$
*/
@RunWith(value=Parameterized.class)
public class HermiteParametricTest extends GaussianQuadratureAbstractTest {
private static final double SQRT_PI = FastMath.sqrt(Math.PI);
private static final GaussIntegratorFactory factory = new GaussIntegratorFactory();
/**
* The highest order quadrature rule to be tested.
*/
public static final int MAX_NUM_POINTS = 30;
/**
* Creates a new instance of this test, with the specified number of nodes
* for the Gauss-Hermite quadrature rule.
*
* @param numberOfPoints Order of integration rule.
* @param maxDegree Maximum degree of monomials to be tested.
* @param eps Value of &epsilon;.
* @param numUlps Value of the maximum relative error (in ulps).
*/
public HermiteParametricTest(int numberOfPoints,
int maxDegree,
double eps,
double numUlps) {
super(factory.hermite(numberOfPoints),
maxDegree, eps, numUlps);
}
/**
* Returns the collection of parameters to be passed to the constructor of
* this class.
* Gauss-Hermite quadrature rules of order 1, ..., {@link #MAX_NUM_POINTS}
* will be constructed.
*
* @return the collection of parameters for this parameterized test.
*/
@Parameters
public static Collection<Object[]> getParameters() {
final ArrayList<Object[]> parameters = new ArrayList<Object[]>();
for (int k = 1; k <= MAX_NUM_POINTS; k++) {
parameters.add(new Object[] { k, 2 * k - 1, Math.ulp(1d), 195 });
}
return parameters;
}
@Override
public double getExpectedValue(final int n) {
if (n % 2 == 1) {
return 0;
}
final int iMax = n / 2;
double p = 1;
double q = 1;
for (int i = 0; i < iMax; i++) {
p *= 2 * i + 1;
q *= 2;
}
return p / q * SQRT_PI;
}
}

View File

@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.integration.gauss;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.function.Gaussian;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.junit.Test;
import org.junit.Assert;
/**
* Test of the {@link HermiteRuleFactory}.
*
* @version $Id$
*/
public class HermiteTest {
private static final GaussIntegratorFactory factory = new GaussIntegratorFactory();
@Test
public void testNormalDistribution() {
final double oneOverSqrtPi = 1 / FastMath.sqrt(Math.PI);
final double mu = 12345.6789;
final double sigma = 987.654321;
// By defintion, Gauss-Hermite quadrature readily provides the
// integral of the normal distribution density.
final int numPoints = 1;
// Change of variable:
// y = (x - mu) / (sqrt(2) * sigma)
// such that the integrand
// N(x, mu, sigma)
// is transformed to
// f(y) * exp(-y^2)
final UnivariateFunction f = new UnivariateFunction() {
@Override
public double value(double y) {
return oneOverSqrtPi; // Constant function.
}
};
final GaussIntegrator integrator = factory.hermite(numPoints);
final double result = integrator.integrate(f);
final double expected = 1;
Assert.assertEquals(expected, result, Math.ulp(expected));
}
@Test
public void testNormalMean() {
final double sqrtTwo = FastMath.sqrt(2);
final double oneOverSqrtPi = 1 / FastMath.sqrt(Math.PI);
final double mu = 12345.6789;
final double sigma = 987.654321;
final int numPoints = 5;
// Change of variable:
// y = (x - mu) / (sqrt(2) * sigma)
// such that the integrand
// x * N(x, mu, sigma)
// is transformed to
// f(y) * exp(-y^2)
final UnivariateFunction f = new UnivariateFunction() {
@Override
public double value(double y) {
return oneOverSqrtPi * (sqrtTwo * sigma * y + mu);
}
};
final GaussIntegrator integrator = factory.hermite(numPoints);
final double result = integrator.integrate(f);
final double expected = mu;
Assert.assertEquals(expected, result, Math.ulp(expected));
}
@Test
public void testNormalVariance() {
final double twoOverSqrtPi = 2 / FastMath.sqrt(Math.PI);
final double mu = 12345.6789;
final double sigma = 987.654321;
final double sigma2 = sigma * sigma;
final int numPoints = 5;
// Change of variable:
// y = (x - mu) / (sqrt(2) * sigma)
// such that the integrand
// (x - mu)^2 * N(x, mu, sigma)
// is transformed to
// f(y) * exp(-y^2)
final UnivariateFunction f = new UnivariateFunction() {
@Override
public double value(double y) {
return twoOverSqrtPi * sigma2 * y * y;
}
};
final GaussIntegrator integrator = factory.hermite(numPoints);
final double result = integrator.integrate(f);
final double expected = sigma2;
Assert.assertEquals(expected, result, 10 * Math.ulp(expected));
}
}