replaced matrix by vector where possible

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@772114 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2009-05-06 09:40:13 +00:00
parent 659d3c4a98
commit 6c64326a17
3 changed files with 21 additions and 22 deletions

View File

@ -18,6 +18,7 @@ package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealMatrixImpl;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
@ -91,12 +92,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @return beta
*/
@Override
protected RealMatrix calculateBeta() {
protected RealVector calculateBeta() {
RealMatrix OI = getOmegaInverse();
RealMatrix XT = X.transpose();
RealMatrix XTOIX = XT.multiply(OI).multiply(X);
RealMatrix inverse = new LUDecompositionImpl(XTOIX).getSolver().getInverse();
return inverse.multiply(XT).multiply(OI).multiply(Y);
return inverse.multiply(XT).multiply(OI).operate(Y);
}
/**
@ -122,9 +123,9 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
@Override
protected double calculateYVariance() {
RealMatrix u = calculateResiduals();
RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u);
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
RealVector residuals = calculateResiduals();
double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
return t / (X.getRowDimension() - X.getColumnDimension());
}
}

View File

@ -18,6 +18,8 @@ package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealMatrixImpl;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.RealVectorImpl;
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
import org.apache.commons.math.linear.decomposition.QRDecomposition;
import org.apache.commons.math.linear.decomposition.QRDecompositionImpl;
@ -137,8 +139,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @return beta
*/
@Override
protected RealMatrix calculateBeta() {
return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y));
protected RealVector calculateBeta() {
return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
}
/**
@ -170,9 +172,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/
@Override
protected double calculateYVariance() {
RealMatrix u = calculateResiduals();
RealMatrix sse = u.transpose().multiply(u);
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
RealVector residuals = calculateResiduals();
return residuals.dotProduct(residuals) /
(X.getRowDimension() - X.getColumnDimension());
}
/** TODO: Find a home for the following methods in the linear package */
@ -191,20 +193,16 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* Similarly, extra (zero) rows in coefficients are ignored</p>
*
* @param coefficients upper-triangular coefficients matrix
* @param constants column RHS constants matrix
* @return solution matrix as a column matrix
* @param constants column RHS constants vector
* @return solution matrix as a column vector
*
*/
private static RealMatrix solveUpperTriangular(RealMatrix coefficients,
RealMatrix constants) {
private static RealVector solveUpperTriangular(RealMatrix coefficients,
RealVector constants) {
if (!isUpperTriangular(coefficients, 1E-12)) {
throw new IllegalArgumentException(
"Coefficients is not upper-triangular");
}
if (constants.getColumnDimension() != 1) {
throw new IllegalArgumentException(
"Constants not a column matrix.");
}
int length = coefficients.getColumnDimension();
double x[] = new double[length];
for (int i = 0; i < length; i++) {
@ -213,9 +211,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
for (int j = index + 1; j < length; j++) {
sum += coefficients.getEntry(index, j) * x[j];
}
x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index);
x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
}
return new RealMatrixImpl(x);
return new RealVectorImpl(x);
}
/**

View File

@ -139,7 +139,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
new double[]{-3482258.63459582, 15.0618722713733,
-0.358191792925910E-01,-2.02022980381683,
-1.03322686717359,-0.511041056535807E-01,
1829.15146461355}, 1E-8); //
1829.15146461355}, 2E-8); //
// Check expected residuals from R
double[] residuals = model.estimateResiduals();
@ -332,7 +332,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
*/
double[] residuals = model.estimateResiduals();
RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0);
double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
}
}