From 7fb87df29458631ece4bc2e284b21ff336c5647e Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Fri, 5 Jul 2013 14:20:19 +0000 Subject: [PATCH] MATH-997 Gauss-Hermite quadrature scheme. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1500018 13f79535-47bb-0310-9956-ffa450edef68 --- src/changes/changes.xml | 8 + .../integration/gauss/GaussIntegrator.java | 20 ++ .../gauss/GaussIntegratorFactory.java | 33 +++- .../integration/gauss/HermiteRuleFactory.java | 179 ++++++++++++++++++ .../gauss/SymmetricGaussIntegrator.java | 106 +++++++++++ .../gauss/GaussIntegratorTest.java | 75 ++++++++ .../gauss/HermiteParametricTest.java | 96 ++++++++++ .../integration/gauss/HermiteTest.java | 120 ++++++++++++ 8 files changed, 634 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/apache/commons/math3/analysis/integration/gauss/HermiteRuleFactory.java create mode 100644 src/main/java/org/apache/commons/math3/analysis/integration/gauss/SymmetricGaussIntegrator.java create mode 100644 src/test/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorTest.java create mode 100644 src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteParametricTest.java create mode 100644 src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteTest.java diff --git a/src/changes/changes.xml b/src/changes/changes.xml index b52adbd4d..8124135d8 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,14 @@ If the output is not quite correct, check for invisible trailing spaces! + + Implemented Gauss-Hermite quadrature scheme (in package + "o.a.c.m.analysis.integration.gauss"). + + + Documented limitation of "IterativeLegendreGaussIntegrator" (added + warning about potential wrong usage). + In "GaussNewtonOptimizer", check for convergence before updating the parameters estimation for the next iteration. diff --git a/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegrator.java b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegrator.java index feeffaa6d..e29f144b4 100644 --- a/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegrator.java +++ b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegrator.java @@ -107,4 +107,24 @@ public class GaussIntegrator { public int getNumberOfPoints() { return points.length; } + + /** + * Gets the integration point at the given index. + * The index must be in the valid range but no check is performed. + * + * @return the integration point. + */ + public double getPoint(int index) { + return points[index]; + } + + /** + * Gets the weight of the integration point at the given index. + * The index must be in the valid range but no check is performed. + * + * @return the weight. + */ + public double getWeight(int index) { + return weights[index]; + } } diff --git a/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorFactory.java b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorFactory.java index c9a5acbee..35df0a073 100644 --- a/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorFactory.java +++ b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorFactory.java @@ -20,7 +20,9 @@ import java.math.BigDecimal; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.analysis.BivariateFunction; import org.apache.commons.math3.util.Pair; +import org.apache.commons.math3.util.FastMath; /** * Class that provides different ways to compute the nodes and weights to be @@ -34,9 +36,12 @@ public class GaussIntegratorFactory { private final BaseRuleFactory legendre = new LegendreRuleFactory(); /** Generator of Gauss-Legendre integrators. */ private final BaseRuleFactory legendreHighPrecision = new LegendreHighPrecisionRuleFactory(); + /** Generator of Gauss-Hermite integrators. */ + private final BaseRuleFactory hermite = new HermiteRuleFactory(); /** - * Creates an integrator of the given order, and whose call to the + * Creates a Gauss-Legendre integrator of the given order. + * The call to the * {@link GaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction) * integrate} method will perform an integration on the natural interval * {@code [-1 , 1]}. @@ -49,7 +54,8 @@ public class GaussIntegratorFactory { } /** - * Creates an integrator of the given order, and whose call to the + * Creates a Gauss-Legendre integrator of the given order. + * The call to the * {@link GaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction) * integrate} method will perform an integration on the given interval. * @@ -68,7 +74,8 @@ public class GaussIntegratorFactory { } /** - * Creates an integrator of the given order, and whose call to the + * Creates a Gauss-Legendre integrator of the given order. + * The call to the * {@link GaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction) * integrate} method will perform an integration on the natural interval * {@code [-1 , 1]}. @@ -101,6 +108,26 @@ public class GaussIntegratorFactory { lowerBound, upperBound)); } + /** + * Creates a Gauss-Hermite integrator of the given order. + * The call to the + * {@link SymmetricGaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction) + * integrate} method will perform a weighted integration on the interval + * {@code [-&inf;, +&inf;]}: the computed value is the improper integral of + * + * e-x2 f(x) + * + * where {@code f(x)} is the function passed to the + * {@link SymmetricGaussIntegrator#integrate(org.apache.commons.math3.analysis.UnivariateFunction) + * integrate} method. + * + * @param numberOfPoints Order of the integration rule. + * @return a Gauss-Hermite integrator. + */ + public SymmetricGaussIntegrator hermite(int numberOfPoints) { + return new SymmetricGaussIntegrator(getRule(hermite, numberOfPoints)); + } + /** * @param factory Integration rule factory. * @param numberOfPoints Order of the integration rule. diff --git a/src/main/java/org/apache/commons/math3/analysis/integration/gauss/HermiteRuleFactory.java b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/HermiteRuleFactory.java new file mode 100644 index 000000000..ac31385dd --- /dev/null +++ b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/HermiteRuleFactory.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.analysis.integration.gauss; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.Pair; +import org.apache.commons.math3.util.FastMath; + +/** + * Factory that creates a + * + * Gauss-type quadrature rule using Hermite polynomials + * of the first kind. + * Such a quadrature rule allows the calculation of improper integrals + * of a function + * + * f(x) e-x2 + * + *
+ * Recurrence relation and weights computation follow + * . + *
+ * The coefficients of the standard Hermite polynomials grow very rapidly; + * in order to avoid overflows, each Hermite polynomial is normalized with + * respect to the underlying scalar product. + * The initial interval for the application of the bisection method is + * based on the roots of the previous Hermite polynomial (interlacing). + * Upper and lower bounds of these roots are provided by + * + * I. Krasikov, + * Nonnegative quadratic forms and bounds on orthogonal polynomials, + * Journal of Approximation theory 111, 31-49 + * + * + * @since 3.3 + * @version $Id$ + */ +public class HermiteRuleFactory extends BaseRuleFactory { + /** π1/2 */ + private static final double SQRT_PI = 1.77245385090551602729; + /** π-1/4 */ + private static final double H0 = 7.5112554446494248286e-1; + /** π-1/4 √2 */ + private static final double H1 = 1.0622519320271969145; + + /** {@inheritDoc} */ + @Override + protected Pair computeRule(int numberOfPoints) + throws DimensionMismatchException { + + if (numberOfPoints == 1) { + // Break recursion. + return new Pair(new Double[] { 0d }, + new Double[] { SQRT_PI }); + } + + // Get previous rule. + // If it has not been computed yet it will trigger a recursive call + // to this method. + final int lastNumPoints = numberOfPoints - 1; + final Double[] previousPoints = getRuleInternal(lastNumPoints).getFirst(); + + // Compute next rule. + final Double[] points = new Double[numberOfPoints]; + final Double[] weights = new Double[numberOfPoints]; + + final double sqrtTwoTimesLastNumPoints = FastMath.sqrt(2 * lastNumPoints); + final double sqrtTwoTimesNumPoints = FastMath.sqrt(2 * numberOfPoints); + + // Find i-th root of H[n+1] by bracketing. + final int iMax = numberOfPoints / 2; + for (int i = 0; i < iMax; i++) { + // Lower-bound of the interval. + double a = (i == 0) ? -sqrtTwoTimesLastNumPoints : previousPoints[i - 1].doubleValue(); + // Upper-bound of the interval. + double b = (iMax == 1) ? -0.5 : previousPoints[i].doubleValue(); + + // H[j-1](a) + double hma = H0; + // H[j](a) + double ha = H1 * a; + // H[j-1](b) + double hmb = H0; + // H[j](b) + double hb = H1 * b; + for (int j = 1; j < numberOfPoints; j++) { + // Compute H[j+1](a) and H[j+1](b) + final double jp1 = j + 1; + final double s = FastMath.sqrt(2 / jp1); + final double sm = FastMath.sqrt(j / jp1); + final double hpa = s * a * ha - sm * hma; + final double hpb = s * b * hb - sm * hmb; + hma = ha; + ha = hpa; + hmb = hb; + hb = hpb; + } + + // Now ha = H[n+1](a), and hma = H[n](a) (same holds for b). + // Middle of the interval. + double c = 0.5 * (a + b); + // P[j-1](c) + double hmc = H0; + // P[j](c) + double hc = H1 * c; + boolean done = false; + while (!done) { + done = b - a <= Math.ulp(c); + hmc = H0; + hc = H1 * c; + for (int j = 1; j < numberOfPoints; j++) { + // Compute H[j+1](c) + final double jp1 = j + 1; + final double s = FastMath.sqrt(2 / jp1); + final double sm = FastMath.sqrt(j / jp1); + final double hpc = s * c * hc - sm * hmc; + hmc = hc; + hc = hpc; + } + // Now h = H[n+1](c) and hm = H[n](c). + if (!done) { + if (ha * hc < 0) { + b = c; + hmb = hmc; + hb = hc; + } else { + a = c; + hma = hmc; + ha = hc; + } + c = 0.5 * (a + b); + } + } + final double d = sqrtTwoTimesNumPoints * hmc; + final double w = 2 / (d * d); + + points[i] = c; + weights[i] = w; + + final int idx = lastNumPoints - i; + points[idx] = -c; + weights[idx] = w; + } + + // If "numberOfPoints" is odd, 0 is a root. + // Note: as written, the test for oddness will work for negative + // integers too (although it is not necessary here), preventing + // a FindBugs warning. + if (numberOfPoints % 2 != 0) { + double hm = H0; + for (int j = 1; j < numberOfPoints; j += 2) { + final double jp1 = j + 1; + hm = -FastMath.sqrt(j / jp1) * hm; + } + final double d = sqrtTwoTimesNumPoints * hm; + final double w = 2 / (d * d); + + points[iMax] = 0d; + weights[iMax] = w; + } + + return new Pair(points, weights); + } +} diff --git a/src/main/java/org/apache/commons/math3/analysis/integration/gauss/SymmetricGaussIntegrator.java b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/SymmetricGaussIntegrator.java new file mode 100644 index 000000000..33cf4e00c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/analysis/integration/gauss/SymmetricGaussIntegrator.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.analysis.integration.gauss; + +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.NonMonotonicSequenceException; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.Pair; + +/** + * This class's implements {@link #integrate(UnivariateFunction) integrate} + * method assuming that the integral is symmetric about 0. + * This allows to reduce numerical errors. + * + * @since 3.3 + * @version $Id$ + */ +public class SymmetricGaussIntegrator extends GaussIntegrator { + /** + * Creates an integrator from the given {@code points} and {@code weights}. + * The integration interval is defined by the first and last value of + * {@code points} which must be sorted in increasing order. + * + * @param points Integration points. + * @param weights Weights of the corresponding integration nodes. + * @throws NonMonotonicSequenceException if the {@code points} are not + * sorted in increasing order. + * @throws DimensionMismatchException if points and weights don't have the same length + */ + public SymmetricGaussIntegrator(double[] points, + double[] weights) + throws NonMonotonicSequenceException, DimensionMismatchException { + super(points, weights); + } + + /** + * Creates an integrator from the given pair of points (first element of + * the pair) and weights (second element of the pair. + * + * @param pointsAndWeights Integration points and corresponding weights. + * @throws NonMonotonicSequenceException if the {@code points} are not + * sorted in increasing order. + * + * @see #SymmetricGaussIntegrator(double[], double[]) + */ + public SymmetricGaussIntegrator(Pair pointsAndWeights) + throws NonMonotonicSequenceException { + this(pointsAndWeights.getFirst(), pointsAndWeights.getSecond()); + } + + /** + * {@inheritDoc} + */ + @Override + public double integrate(UnivariateFunction f) { + final int ruleLength = getNumberOfPoints(); + + if (ruleLength == 1) { + return getWeight(0) * f.value(0d); + } + + final int iMax = ruleLength / 2; + double s = 0; + double c = 0; + for (int i = 0; i < iMax; i++) { + final double p = getPoint(i); + final double w = getWeight(i); + + final double f1 = f.value(p); + final double f2 = f.value(-p); + + final double y = w * (f1 + f2) - c; + final double t = s + y; + + c = (t - s) - y; + s = t; + } + + if (ruleLength % 2 == 1) { + final double w = getWeight(iMax); + + final double y = w * f.value(0d) - c; + final double t = s + y; + + c = (t - s) - y; + s = t; + } + + return s; + } +} diff --git a/src/test/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorTest.java b/src/test/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorTest.java new file mode 100644 index 000000000..2265e9a47 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/analysis/integration/gauss/GaussIntegratorTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.analysis.integration.gauss; + +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.function.Constant; +import org.apache.commons.math3.util.Pair; +import org.junit.Test; +import org.junit.Assert; + +/** + * Test for {@link GaussIntegrator} class. + * + * @version $Id$ + */ +public class GaussIntegratorTest { + @Test + public void testGetWeights() { + final double[] points = { 0, 1.2, 3.4 }; + final double[] weights = { 9.8, 7.6, 5.4 }; + + final GaussIntegrator integrator + = new GaussIntegrator(new Pair(points, weights)); + + Assert.assertEquals(weights.length, integrator.getNumberOfPoints()); + + for (int i = 0; i < integrator.getNumberOfPoints(); i++) { + Assert.assertEquals(weights[i], integrator.getWeight(i), 0d); + } + } + + @Test + public void testGetPoints() { + final double[] points = { 0, 1.2, 3.4 }; + final double[] weights = { 9.8, 7.6, 5.4 }; + + final GaussIntegrator integrator + = new GaussIntegrator(new Pair(points, weights)); + + Assert.assertEquals(points.length, integrator.getNumberOfPoints()); + + for (int i = 0; i < integrator.getNumberOfPoints(); i++) { + Assert.assertEquals(points[i], integrator.getPoint(i), 0d); + } + } + + @Test + public void testIntegrate() { + final double[] points = { 0, 1, 2, 3, 4, 5 }; + final double[] weights = { 1, 1, 1, 1, 1, 1 }; + + final GaussIntegrator integrator + = new GaussIntegrator(new Pair(points, weights)); + + final double val = 123.456; + final UnivariateFunction c = new Constant(val); + + final double s = integrator.integrate(c); + Assert.assertEquals(points.length * val, s, 0d); + } +} diff --git a/src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteParametricTest.java b/src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteParametricTest.java new file mode 100644 index 000000000..d4c8a4d44 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteParametricTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.analysis.integration.gauss; + +import java.util.ArrayList; +import java.util.Collection; + +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.apache.commons.math3.util.FastMath; + +/** + * Test of the {@link HermiteRuleFactory}. + * This parameterized test extends the standard test for Gaussian quadrature + * rule, where each monomial is tested in turn. + * Parametrization allows to test automatically 0, 1, ... , {@link #MAX_NUM_POINTS} + * quadrature rules. + * + * @version $Id$ + */ +@RunWith(value=Parameterized.class) +public class HermiteParametricTest extends GaussianQuadratureAbstractTest { + private static final double SQRT_PI = FastMath.sqrt(Math.PI); + private static final GaussIntegratorFactory factory = new GaussIntegratorFactory(); + + /** + * The highest order quadrature rule to be tested. + */ + public static final int MAX_NUM_POINTS = 30; + + /** + * Creates a new instance of this test, with the specified number of nodes + * for the Gauss-Hermite quadrature rule. + * + * @param numberOfPoints Order of integration rule. + * @param maxDegree Maximum degree of monomials to be tested. + * @param eps Value of ε. + * @param numUlps Value of the maximum relative error (in ulps). + */ + public HermiteParametricTest(int numberOfPoints, + int maxDegree, + double eps, + double numUlps) { + super(factory.hermite(numberOfPoints), + maxDegree, eps, numUlps); + } + + /** + * Returns the collection of parameters to be passed to the constructor of + * this class. + * Gauss-Hermite quadrature rules of order 1, ..., {@link #MAX_NUM_POINTS} + * will be constructed. + * + * @return the collection of parameters for this parameterized test. + */ + @Parameters + public static Collection getParameters() { + final ArrayList parameters = new ArrayList(); + for (int k = 1; k <= MAX_NUM_POINTS; k++) { + parameters.add(new Object[] { k, 2 * k - 1, Math.ulp(1d), 195 }); + } + return parameters; + } + + @Override + public double getExpectedValue(final int n) { + if (n % 2 == 1) { + return 0; + } + + final int iMax = n / 2; + double p = 1; + double q = 1; + for (int i = 0; i < iMax; i++) { + p *= 2 * i + 1; + q *= 2; + } + + return p / q * SQRT_PI; + } +} diff --git a/src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteTest.java b/src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteTest.java new file mode 100644 index 000000000..7069f7353 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/analysis/integration/gauss/HermiteTest.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.analysis.integration.gauss; + +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.function.Gaussian; +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Pair; +import org.junit.Test; +import org.junit.Assert; + +/** + * Test of the {@link HermiteRuleFactory}. + * + * @version $Id$ + */ +public class HermiteTest { + private static final GaussIntegratorFactory factory = new GaussIntegratorFactory(); + + @Test + public void testNormalDistribution() { + final double oneOverSqrtPi = 1 / FastMath.sqrt(Math.PI); + + final double mu = 12345.6789; + final double sigma = 987.654321; + // By defintion, Gauss-Hermite quadrature readily provides the + // integral of the normal distribution density. + final int numPoints = 1; + + // Change of variable: + // y = (x - mu) / (sqrt(2) * sigma) + // such that the integrand + // N(x, mu, sigma) + // is transformed to + // f(y) * exp(-y^2) + final UnivariateFunction f = new UnivariateFunction() { + @Override + public double value(double y) { + return oneOverSqrtPi; // Constant function. + } + }; + + final GaussIntegrator integrator = factory.hermite(numPoints); + final double result = integrator.integrate(f); + final double expected = 1; + Assert.assertEquals(expected, result, Math.ulp(expected)); + } + + @Test + public void testNormalMean() { + final double sqrtTwo = FastMath.sqrt(2); + final double oneOverSqrtPi = 1 / FastMath.sqrt(Math.PI); + + final double mu = 12345.6789; + final double sigma = 987.654321; + final int numPoints = 5; + + // Change of variable: + // y = (x - mu) / (sqrt(2) * sigma) + // such that the integrand + // x * N(x, mu, sigma) + // is transformed to + // f(y) * exp(-y^2) + final UnivariateFunction f = new UnivariateFunction() { + @Override + public double value(double y) { + return oneOverSqrtPi * (sqrtTwo * sigma * y + mu); + } + }; + + final GaussIntegrator integrator = factory.hermite(numPoints); + final double result = integrator.integrate(f); + final double expected = mu; + Assert.assertEquals(expected, result, Math.ulp(expected)); + } + + @Test + public void testNormalVariance() { + final double twoOverSqrtPi = 2 / FastMath.sqrt(Math.PI); + + final double mu = 12345.6789; + final double sigma = 987.654321; + final double sigma2 = sigma * sigma; + final int numPoints = 5; + + // Change of variable: + // y = (x - mu) / (sqrt(2) * sigma) + // such that the integrand + // (x - mu)^2 * N(x, mu, sigma) + // is transformed to + // f(y) * exp(-y^2) + final UnivariateFunction f = new UnivariateFunction() { + @Override + public double value(double y) { + return twoOverSqrtPi * sigma2 * y * y; + } + }; + + final GaussIntegrator integrator = factory.hermite(numPoints); + final double result = integrator.integrate(f); + final double expected = sigma2; + Assert.assertEquals(expected, result, 10 * Math.ulp(expected)); + } +}