MATH-924
Avoid memory exhaustion for large number of unclorrelated observations. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1426759 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
b07ecae3d6
commit
2836a6f9ef
|
@ -18,7 +18,7 @@
|
|||
package org.apache.commons.math3.optimization;
|
||||
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||
import org.apache.commons.math3.linear.NonSquareMatrixException;
|
||||
|
||||
/**
|
||||
|
@ -41,11 +41,7 @@ public class Weight implements OptimizationData {
|
|||
* @param weight List of the values of the diagonal.
|
||||
*/
|
||||
public Weight(double[] weight) {
|
||||
final int dim = weight.length;
|
||||
weightMatrix = new Array2DRowRealMatrix(dim, dim);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
weightMatrix.setEntry(i, i, weight[i]);
|
||||
}
|
||||
weightMatrix = new DiagonalMatrix(weight);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
|||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.DiagonalMatrix;
|
||||
import org.apache.commons.math3.linear.DecompositionSolver;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.QRDecomposition;
|
||||
|
@ -558,7 +559,16 @@ public abstract class AbstractLeastSquaresOptimizer
|
|||
* @return the square-root of the weight matrix.
|
||||
*/
|
||||
private RealMatrix squareRoot(RealMatrix m) {
|
||||
final EigenDecomposition dec = new EigenDecomposition(m);
|
||||
return dec.getSquareRoot();
|
||||
if (m instanceof DiagonalMatrix) {
|
||||
final int dim = m.getRowDimension();
|
||||
final RealMatrix sqrtM = new DiagonalMatrix(dim);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
sqrtM.setEntry(i, i, FastMath.sqrt(m.getEntry(i, i)));
|
||||
}
|
||||
return sqrtM;
|
||||
} else {
|
||||
final EigenDecomposition dec = new EigenDecomposition(m);
|
||||
return dec.getSquareRoot();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -223,6 +223,33 @@ public class PolynomialFitterTest {
|
|||
checkUnsolvableProblem(new GaussNewtonOptimizer(true, new SimpleVectorValueChecker(1e-15, 1e-15)), false);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLargeSample() {
|
||||
Random randomizer = new Random(0x5551480dca5b369bl);
|
||||
double maxError = 0;
|
||||
for (int degree = 0; degree < 10; ++degree) {
|
||||
PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
|
||||
|
||||
PolynomialFitter fitter = new PolynomialFitter(new LevenbergMarquardtOptimizer());
|
||||
for (int i = 0; i < 40000; ++i) {
|
||||
double x = -1.0 + i / 20000.0;
|
||||
fitter.addObservedPoint(1.0, x,
|
||||
p.value(x) + 0.1 * randomizer.nextGaussian());
|
||||
}
|
||||
|
||||
final double[] init = new double[degree + 1];
|
||||
PolynomialFunction fitted = new PolynomialFunction(fitter.fit(init));
|
||||
|
||||
for (double x = -1.0; x < 1.0; x += 0.01) {
|
||||
double error = FastMath.abs(p.value(x) - fitted.value(x)) /
|
||||
(1.0 + FastMath.abs(p.value(x)));
|
||||
maxError = FastMath.max(maxError, error);
|
||||
Assert.assertTrue(FastMath.abs(error) < 0.01);
|
||||
}
|
||||
}
|
||||
Assert.assertTrue(maxError > 0.001);
|
||||
}
|
||||
|
||||
private void checkUnsolvableProblem(DifferentiableMultivariateVectorOptimizer optimizer,
|
||||
boolean solvable) {
|
||||
Random randomizer = new Random(1248788532l);
|
||||
|
|
Loading…
Reference in New Issue