diff --git a/src/main/java/org/apache/commons/math3/optimization/Weight.java b/src/main/java/org/apache/commons/math3/optimization/Weight.java index 8e7538f22..28c161903 100644 --- a/src/main/java/org/apache/commons/math3/optimization/Weight.java +++ b/src/main/java/org/apache/commons/math3/optimization/Weight.java @@ -18,7 +18,7 @@ package org.apache.commons.math3.optimization; import org.apache.commons.math3.linear.RealMatrix; -import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.NonSquareMatrixException; /** @@ -41,11 +41,7 @@ public class Weight implements OptimizationData { * @param weight List of the values of the diagonal. */ public Weight(double[] weight) { - final int dim = weight.length; - weightMatrix = new Array2DRowRealMatrix(dim, dim); - for (int i = 0; i < dim; i++) { - weightMatrix.setEntry(i, i, weight[i]); - } + weightMatrix = new DiagonalMatrix(weight); } /** diff --git a/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java b/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java index b6c97e7e5..982e559eb 100644 --- a/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java +++ b/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java @@ -26,6 +26,7 @@ import org.apache.commons.math3.exception.NumberIsTooSmallException; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.QRDecomposition; @@ -558,7 +559,16 @@ public abstract class AbstractLeastSquaresOptimizer * @return the square-root of the weight matrix. */ private RealMatrix squareRoot(RealMatrix m) { - final EigenDecomposition dec = new EigenDecomposition(m); - return dec.getSquareRoot(); + if (m instanceof DiagonalMatrix) { + final int dim = m.getRowDimension(); + final RealMatrix sqrtM = new DiagonalMatrix(dim); + for (int i = 0; i < dim; i++) { + sqrtM.setEntry(i, i, FastMath.sqrt(m.getEntry(i, i))); + } + return sqrtM; + } else { + final EigenDecomposition dec = new EigenDecomposition(m); + return dec.getSquareRoot(); + } } } diff --git a/src/test/java/org/apache/commons/math3/optimization/fitting/PolynomialFitterTest.java b/src/test/java/org/apache/commons/math3/optimization/fitting/PolynomialFitterTest.java index 7f7743e03..74bd8eee2 100644 --- a/src/test/java/org/apache/commons/math3/optimization/fitting/PolynomialFitterTest.java +++ b/src/test/java/org/apache/commons/math3/optimization/fitting/PolynomialFitterTest.java @@ -223,6 +223,33 @@ public class PolynomialFitterTest { checkUnsolvableProblem(new GaussNewtonOptimizer(true, new SimpleVectorValueChecker(1e-15, 1e-15)), false); } + @Test + public void testLargeSample() { + Random randomizer = new Random(0x5551480dca5b369bl); + double maxError = 0; + for (int degree = 0; degree < 10; ++degree) { + PolynomialFunction p = buildRandomPolynomial(degree, randomizer); + + PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer()); + for (int i = 0; i < 40000; ++i) { + double x = -1.0 + i / 20000.0; + fitter.addObservedPoint(1.0, x, + p.value(x) + 0.1 * randomizer.nextGaussian()); + } + + final double[] init = new double[degree + 1]; + PolynomialFunction fitted = new PolynomialFunction(fitter.fit(init)); + + for (double x = -1.0; x < 1.0; x += 0.01) { + double error = FastMath.abs(p.value(x) - fitted.value(x)) / + (1.0 + FastMath.abs(p.value(x))); + maxError = FastMath.max(maxError, error); + Assert.assertTrue(FastMath.abs(error) < 0.01); + } + } + Assert.assertTrue(maxError > 0.001); + } + private void checkUnsolvableProblem(DifferentiableMultivariateVectorOptimizer optimizer, boolean solvable) { Random randomizer = new Random(1248788532l);