replaced matrix by vector where possible
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@772114 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
659d3c4a98
commit
6c64326a17
|
@ -18,6 +18,7 @@ package org.apache.commons.math.stat.regression;
|
|||
|
||||
import org.apache.commons.math.linear.RealMatrix;
|
||||
import org.apache.commons.math.linear.RealMatrixImpl;
|
||||
import org.apache.commons.math.linear.RealVector;
|
||||
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
|
||||
|
||||
|
||||
|
@ -91,12 +92,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
* @return beta
|
||||
*/
|
||||
@Override
|
||||
protected RealMatrix calculateBeta() {
|
||||
protected RealVector calculateBeta() {
|
||||
RealMatrix OI = getOmegaInverse();
|
||||
RealMatrix XT = X.transpose();
|
||||
RealMatrix XTOIX = XT.multiply(OI).multiply(X);
|
||||
RealMatrix inverse = new LUDecompositionImpl(XTOIX).getSolver().getInverse();
|
||||
return inverse.multiply(XT).multiply(OI).multiply(Y);
|
||||
return inverse.multiply(XT).multiply(OI).operate(Y);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -122,9 +123,9 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
*/
|
||||
@Override
|
||||
protected double calculateYVariance() {
|
||||
RealMatrix u = calculateResiduals();
|
||||
RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u);
|
||||
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
|
||||
RealVector residuals = calculateResiduals();
|
||||
double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
|
||||
return t / (X.getRowDimension() - X.getColumnDimension());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ package org.apache.commons.math.stat.regression;
|
|||
|
||||
import org.apache.commons.math.linear.RealMatrix;
|
||||
import org.apache.commons.math.linear.RealMatrixImpl;
|
||||
import org.apache.commons.math.linear.RealVector;
|
||||
import org.apache.commons.math.linear.RealVectorImpl;
|
||||
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
|
||||
import org.apache.commons.math.linear.decomposition.QRDecomposition;
|
||||
import org.apache.commons.math.linear.decomposition.QRDecompositionImpl;
|
||||
|
@ -137,8 +139,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
* @return beta
|
||||
*/
|
||||
@Override
|
||||
protected RealMatrix calculateBeta() {
|
||||
return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y));
|
||||
protected RealVector calculateBeta() {
|
||||
return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -170,9 +172,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
*/
|
||||
@Override
|
||||
protected double calculateYVariance() {
|
||||
RealMatrix u = calculateResiduals();
|
||||
RealMatrix sse = u.transpose().multiply(u);
|
||||
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
|
||||
RealVector residuals = calculateResiduals();
|
||||
return residuals.dotProduct(residuals) /
|
||||
(X.getRowDimension() - X.getColumnDimension());
|
||||
}
|
||||
|
||||
/** TODO: Find a home for the following methods in the linear package */
|
||||
|
@ -191,20 +193,16 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
* Similarly, extra (zero) rows in coefficients are ignored</p>
|
||||
*
|
||||
* @param coefficients upper-triangular coefficients matrix
|
||||
* @param constants column RHS constants matrix
|
||||
* @return solution matrix as a column matrix
|
||||
* @param constants column RHS constants vector
|
||||
* @return solution matrix as a column vector
|
||||
*
|
||||
*/
|
||||
private static RealMatrix solveUpperTriangular(RealMatrix coefficients,
|
||||
RealMatrix constants) {
|
||||
private static RealVector solveUpperTriangular(RealMatrix coefficients,
|
||||
RealVector constants) {
|
||||
if (!isUpperTriangular(coefficients, 1E-12)) {
|
||||
throw new IllegalArgumentException(
|
||||
"Coefficients is not upper-triangular");
|
||||
}
|
||||
if (constants.getColumnDimension() != 1) {
|
||||
throw new IllegalArgumentException(
|
||||
"Constants not a column matrix.");
|
||||
}
|
||||
int length = coefficients.getColumnDimension();
|
||||
double x[] = new double[length];
|
||||
for (int i = 0; i < length; i++) {
|
||||
|
@ -213,9 +211,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
for (int j = index + 1; j < length; j++) {
|
||||
sum += coefficients.getEntry(index, j) * x[j];
|
||||
}
|
||||
x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index);
|
||||
x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
|
||||
}
|
||||
return new RealMatrixImpl(x);
|
||||
return new RealVectorImpl(x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -139,7 +139,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
|||
new double[]{-3482258.63459582, 15.0618722713733,
|
||||
-0.358191792925910E-01,-2.02022980381683,
|
||||
-1.03322686717359,-0.511041056535807E-01,
|
||||
1829.15146461355}, 1E-8); //
|
||||
1829.15146461355}, 2E-8); //
|
||||
|
||||
// Check expected residuals from R
|
||||
double[] residuals = model.estimateResiduals();
|
||||
|
@ -332,7 +332,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
|||
*/
|
||||
double[] residuals = model.estimateResiduals();
|
||||
RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
|
||||
double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0);
|
||||
double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
|
||||
TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue