Replaced calls to deprecated methods from the linear algebra package.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@705211 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-10-16 11:34:38 +00:00
parent c376e656d7
commit 098eaa3a78
3 changed files with 34 additions and 12 deletions

View File

@ -20,8 +20,11 @@ package org.apache.commons.math.estimation;
import java.io.Serializable;
import org.apache.commons.math.linear.InvalidMatrixException;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealMatrixImpl;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.RealVectorImpl;
/**
* This class implements a solver for estimation problems.
@ -106,8 +109,8 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
// work matrices
double[] grad = new double[parameters.length];
RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1);
double[][] bDecrementData = bDecrement.getDataRef();
RealVectorImpl bDecrement = new RealVectorImpl(parameters.length);
double[] bDecrementData = bDecrement.getDataRef();
RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
double[][] wggData = wGradGradT.getDataRef();
@ -117,7 +120,7 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
// build the linear problem
incrementJacobianEvaluationsCounter();
RealMatrix b = new RealMatrixImpl(parameters.length, 1);
RealVector b = new RealVectorImpl(parameters.length);
RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length);
for (int i = 0; i < measurements.length; ++i) {
if (! measurements [i].isIgnored()) {
@ -128,7 +131,7 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
// compute the normal equation
for (int j = 0; j < parameters.length; ++j) {
grad[j] = measurements[i].getPartial(parameters[j]);
bDecrementData[j][0] = weight * residual * grad[j];
bDecrementData[j] = weight * residual * grad[j];
}
// build the contribution matrix for measurement i
@ -150,11 +153,11 @@ public class GaussNewtonEstimator extends AbstractEstimator implements Serializa
try {
// solve the linearized least squares problem
RealMatrix dX = a.solve(b);
RealVector dX = new LUDecompositionImpl(a).solve(b);
// update the estimated parameters
for (int i = 0; i < parameters.length; ++i) {
parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i));
}
} catch(InvalidMatrixException e) {

View File

@ -16,6 +16,7 @@
*/
package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealMatrixImpl;
@ -44,6 +45,9 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
/** Covariance matrix. */
private RealMatrix Omega;
/** Inverse of covariance matrix. */
private RealMatrix OmegaInverse;
public void newSampleData(double[] y, double[][] x, double[][] covariance) {
validateSampleData(x, y);
newYSampleData(y);
@ -59,6 +63,19 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
/**
 * Replaces the covariance data used by this regression.
 * @param omega covariance matrix entries, copied into a new {@code RealMatrixImpl}
 */
protected void newCovarianceData(double[][] omega){
this.Omega = new RealMatrixImpl(omega);
// Invalidate the cached inverse; it is lazily recomputed by getOmegaInverse().
this.OmegaInverse = null;
}
/**
 * Get the inverse of the covariance.
 * <p>The inverse of the covariance matrix is lazily evaluated and cached,
 * and is invalidated whenever new covariance data is supplied.</p>
 * <p>NOTE(review): not thread-safe — concurrent callers may each compute
 * the inverse; confirm single-threaded use.</p>
 * @return inverse of the covariance
 */
protected RealMatrix getOmegaInverse() {
if (OmegaInverse == null) {
// Compute once via LU decomposition (replacement for deprecated RealMatrix.inverse()).
OmegaInverse = new LUDecompositionImpl(Omega).getInverse();
}
return OmegaInverse;
}
/**
@ -69,10 +86,10 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @return beta
*/
protected RealMatrix calculateBeta() {
RealMatrix OI = Omega.inverse();
RealMatrix OI = getOmegaInverse();
RealMatrix XT = X.transpose();
RealMatrix XTOIX = XT.multiply(OI).multiply(X);
return XTOIX.inverse().multiply(XT).multiply(OI).multiply(Y);
return new LUDecompositionImpl(XTOIX).getInverse().multiply(XT).multiply(OI).multiply(Y);
}
/**
@ -83,8 +100,9 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @return The beta variance matrix
*/
protected RealMatrix calculateBetaVariance() {
RealMatrix XTOIX = X.transpose().multiply(Omega.inverse()).multiply(X);
return XTOIX.inverse();
RealMatrix OI = getOmegaInverse();
RealMatrix XTOIX = X.transpose().multiply(OI).multiply(X);
return new LUDecompositionImpl(XTOIX).getInverse();
}
/**
@ -96,7 +114,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
protected double calculateYVariance() {
RealMatrix u = calculateResiduals();
RealMatrix sse = u.transpose().multiply(Omega.inverse()).multiply(u);
RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u);
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
}

View File

@ -16,6 +16,7 @@
*/
package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
@ -107,7 +108,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
protected RealMatrix calculateBetaVariance() {
RealMatrix XTX = X.transpose().multiply(X);
return XTX.inverse();
return new LUDecompositionImpl(XTX).getInverse();
}