Variable visibility: "protected" -> "private". Added "protected"

getter methods.


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1296570 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Gilles Sadowski 2012-03-03 03:35:20 +00:00
parent 92a4d6b1a0
commit 12482617ae
6 changed files with 63 additions and 49 deletions

View File

@ -39,14 +39,28 @@ public abstract class AbstractMultipleLinearRegression implements
MultipleLinearRegression { MultipleLinearRegression {
/** X sample data. */ /** X sample data. */
protected RealMatrix X; private RealMatrix xMatrix;
/** Y sample data. */ /** Y sample data. */
protected RealVector Y; private RealVector yVector;
/** Whether or not the regression model includes an intercept. True means no intercept. */ /** Whether or not the regression model includes an intercept. True means no intercept. */
private boolean noIntercept = false; private boolean noIntercept = false;
/**
* @return the X sample data.
*/
protected RealMatrix getX() {
return xMatrix;
}
/**
* @return the Y sample data.
*/
protected RealVector getY() {
return yVector;
}
/** /**
* @return true if the model has no intercept term; false otherwise * @return true if the model has no intercept term; false otherwise
* @since 2.2 * @since 2.2
@ -121,8 +135,8 @@ public abstract class AbstractMultipleLinearRegression implements
x[i][j] = data[pointer++]; x[i][j] = data[pointer++];
} }
} }
this.X = new Array2DRowRealMatrix(x); this.xMatrix = new Array2DRowRealMatrix(x);
this.Y = new ArrayRealVector(y); this.yVector = new ArrayRealVector(y);
} }
/** /**
@ -139,7 +153,7 @@ public abstract class AbstractMultipleLinearRegression implements
if (y.length == 0) { if (y.length == 0) {
throw new NoDataException(); throw new NoDataException();
} }
this.Y = new ArrayRealVector(y); this.yVector = new ArrayRealVector(y);
} }
/** /**
@ -175,7 +189,7 @@ public abstract class AbstractMultipleLinearRegression implements
throw new NoDataException(); throw new NoDataException();
} }
if (noIntercept) { if (noIntercept) {
this.X = new Array2DRowRealMatrix(x, true); this.xMatrix = new Array2DRowRealMatrix(x, true);
} else { // Augment design matrix with initial unitary column } else { // Augment design matrix with initial unitary column
final int nVars = x[0].length; final int nVars = x[0].length;
final double[][] xAug = new double[x.length][nVars + 1]; final double[][] xAug = new double[x.length][nVars + 1];
@ -186,7 +200,7 @@ public abstract class AbstractMultipleLinearRegression implements
xAug[i][0] = 1.0d; xAug[i][0] = 1.0d;
System.arraycopy(x[i], 0, xAug[i], 1, nVars); System.arraycopy(x[i], 0, xAug[i], 1, nVars);
} }
this.X = new Array2DRowRealMatrix(xAug, false); this.xMatrix = new Array2DRowRealMatrix(xAug, false);
} }
} }
@ -257,7 +271,7 @@ public abstract class AbstractMultipleLinearRegression implements
*/ */
public double[] estimateResiduals() { public double[] estimateResiduals() {
RealVector b = calculateBeta(); RealVector b = calculateBeta();
RealVector e = Y.subtract(X.operate(b)); RealVector e = yVector.subtract(xMatrix.operate(b));
return e.toArray(); return e.toArray();
} }
@ -332,7 +346,7 @@ public abstract class AbstractMultipleLinearRegression implements
* @return Y variance * @return Y variance
*/ */
protected double calculateYVariance() { protected double calculateYVariance() {
return new Variance().evaluate(Y.toArray()); return new Variance().evaluate(yVector.toArray());
} }
/** /**
@ -349,7 +363,7 @@ public abstract class AbstractMultipleLinearRegression implements
protected double calculateErrorVariance() { protected double calculateErrorVariance() {
RealVector residuals = calculateResiduals(); RealVector residuals = calculateResiduals();
return residuals.dotProduct(residuals) / return residuals.dotProduct(residuals) /
(X.getRowDimension() - X.getColumnDimension()); (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
} }
/** /**
@ -364,7 +378,7 @@ public abstract class AbstractMultipleLinearRegression implements
*/ */
protected RealVector calculateResiduals() { protected RealVector calculateResiduals() {
RealVector b = calculateBeta(); RealVector b = calculateBeta();
return Y.subtract(X.operate(b)); return yVector.subtract(xMatrix.operate(b));
} }
} }

View File

@ -93,10 +93,10 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
@Override @Override
protected RealVector calculateBeta() { protected RealVector calculateBeta() {
RealMatrix OI = getOmegaInverse(); RealMatrix OI = getOmegaInverse();
RealMatrix XT = X.transpose(); RealMatrix XT = getX().transpose();
RealMatrix XTOIX = XT.multiply(OI).multiply(X); RealMatrix XTOIX = XT.multiply(OI).multiply(getX());
RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse(); RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse();
return inverse.multiply(XT).multiply(OI).operate(Y); return inverse.multiply(XT).multiply(OI).operate(getY());
} }
/** /**
@ -109,7 +109,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
@Override @Override
protected RealMatrix calculateBetaVariance() { protected RealMatrix calculateBetaVariance() {
RealMatrix OI = getOmegaInverse(); RealMatrix OI = getOmegaInverse();
RealMatrix XTOIX = X.transpose().multiply(OI).multiply(X); RealMatrix XTOIX = getX().transpose().multiply(OI).multiply(getX());
return new LUDecomposition(XTOIX).getSolver().getInverse(); return new LUDecomposition(XTOIX).getSolver().getInverse();
} }
@ -129,7 +129,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
protected double calculateErrorVariance() { protected double calculateErrorVariance() {
RealVector residuals = calculateResiduals(); RealVector residuals = calculateResiduals();
double t = residuals.dotProduct(getOmegaInverse().operate(residuals)); double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
return t / (X.getRowDimension() - X.getColumnDimension()); return t / (getX().getRowDimension() - getX().getColumnDimension());
} }

View File

@ -78,7 +78,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
@Override @Override
public void newSampleData(double[] data, int nobs, int nvars) { public void newSampleData(double[] data, int nobs, int nvars) {
super.newSampleData(data, nobs, nvars); super.newSampleData(data, nobs, nvars);
qr = new QRDecomposition(X); qr = new QRDecomposition(getX());
} }
/** /**
@ -132,9 +132,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/ */
public double calculateTotalSumOfSquares() { public double calculateTotalSumOfSquares() {
if (isNoIntercept()) { if (isNoIntercept()) {
return StatUtils.sumSq(Y.toArray()); return StatUtils.sumSq(getY().toArray());
} else { } else {
return new SecondMoment().evaluate(Y.toArray()); return new SecondMoment().evaluate(getY().toArray());
} }
} }
@ -180,12 +180,12 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @since 2.2 * @since 2.2
*/ */
public double calculateAdjustedRSquared() { public double calculateAdjustedRSquared() {
final double n = X.getRowDimension(); final double n = getX().getRowDimension();
if (isNoIntercept()) { if (isNoIntercept()) {
return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension())); return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
} else { } else {
return 1 - (calculateResidualSumOfSquares() * (n - 1)) / return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
(calculateTotalSumOfSquares() * (n - X.getColumnDimension())); (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
} }
} }
@ -197,7 +197,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
@Override @Override
protected void newXSampleData(double[][] x) { protected void newXSampleData(double[][] x) {
super.newXSampleData(x); super.newXSampleData(x);
qr = new QRDecomposition(X); qr = new QRDecomposition(getX());
} }
/** /**
@ -207,7 +207,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/ */
@Override @Override
protected RealVector calculateBeta() { protected RealVector calculateBeta() {
return qr.getSolver().solve(Y); return qr.getSolver().solve(getY());
} }
/** /**
@ -223,7 +223,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*/ */
@Override @Override
protected RealMatrix calculateBetaVariance() { protected RealMatrix calculateBetaVariance() {
int p = X.getColumnDimension(); int p = getX().getColumnDimension();
RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1); RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse(); RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
return Rinv.multiply(Rinv.transpose()); return Rinv.multiply(Rinv.transpose());

View File

@ -178,13 +178,13 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
double[][] covariance = MatrixUtils.createRealIdentityMatrix(4).scalarMultiply(2).getData(); double[][] covariance = MatrixUtils.createRealIdentityMatrix(4).scalarMultiply(2).getData();
GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression(); GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression();
regression.newSampleData(y, x, covariance); regression.newSampleData(y, x, covariance);
RealMatrix combinedX = regression.X.copy(); RealMatrix combinedX = regression.getX().copy();
RealVector combinedY = regression.Y.copy(); RealVector combinedY = regression.getY().copy();
RealMatrix combinedCovInv = regression.getOmegaInverse(); RealMatrix combinedCovInv = regression.getOmegaInverse();
regression.newXSampleData(x); regression.newXSampleData(x);
regression.newYSampleData(y); regression.newYSampleData(y);
Assert.assertEquals(combinedX, regression.X); Assert.assertEquals(combinedX, regression.getX());
Assert.assertEquals(combinedY, regression.Y); Assert.assertEquals(combinedY, regression.getY());
Assert.assertEquals(combinedCovInv, regression.getOmegaInverse()); Assert.assertEquals(combinedCovInv, regression.getOmegaInverse());
} }
@ -253,7 +253,7 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression(); OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
ols.newSampleData(longley, nObs, 6); ols.newSampleData(longley, nObs, 6);
final RealVector b = ols.calculateBeta().copy(); final RealVector b = ols.calculateBeta().copy();
final RealMatrix x = ols.X.copy(); final RealMatrix x = ols.getX().copy();
// Create a GLS model to reuse // Create a GLS model to reuse
GLSMultipleLinearRegression gls = new GLSMultipleLinearRegression(); GLSMultipleLinearRegression gls = new GLSMultipleLinearRegression();

View File

@ -86,22 +86,22 @@ public abstract class MultipleLinearRegressionAbstractTest {
}; };
AbstractMultipleLinearRegression regression = createRegression(); AbstractMultipleLinearRegression regression = createRegression();
regression.newSampleData(design, 4, 3); regression.newSampleData(design, 4, 3);
RealMatrix flatX = regression.X.copy(); RealMatrix flatX = regression.getX().copy();
RealVector flatY = regression.Y.copy(); RealVector flatY = regression.getY().copy();
regression.newXSampleData(x); regression.newXSampleData(x);
regression.newYSampleData(y); regression.newYSampleData(y);
Assert.assertEquals(flatX, regression.X); Assert.assertEquals(flatX, regression.getX());
Assert.assertEquals(flatY, regression.Y); Assert.assertEquals(flatY, regression.getY());
// No intercept // No intercept
regression.setNoIntercept(true); regression.setNoIntercept(true);
regression.newSampleData(design, 4, 3); regression.newSampleData(design, 4, 3);
flatX = regression.X.copy(); flatX = regression.getX().copy();
flatY = regression.Y.copy(); flatY = regression.getY().copy();
regression.newXSampleData(x); regression.newXSampleData(x);
regression.newYSampleData(y); regression.newYSampleData(y);
Assert.assertEquals(flatX, regression.X); Assert.assertEquals(flatX, regression.getX());
Assert.assertEquals(flatY, regression.Y); Assert.assertEquals(flatY, regression.getY());
} }
@Test(expected=IllegalArgumentException.class) @Test(expected=IllegalArgumentException.class)

View File

@ -434,7 +434,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
*/ */
double[] residuals = model.estimateResiduals(); double[] residuals = model.estimateResiduals();
RealMatrix I = MatrixUtils.createRealIdentityMatrix(10); RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
double[] hatResiduals = I.subtract(hat).operate(model.Y).toArray(); double[] hatResiduals = I.subtract(hat).operate(model.getY()).toArray();
TestUtils.assertEquals(residuals, hatResiduals, 10e-12); TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
} }
@ -457,11 +457,11 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
*/ */
protected void checkVarianceConsistency(OLSMultipleLinearRegression model) throws Exception { protected void checkVarianceConsistency(OLSMultipleLinearRegression model) throws Exception {
// Check Y variance consistency // Check Y variance consistency
TestUtils.assertEquals(StatUtils.variance(model.Y.toArray()), model.calculateYVariance(), 0); TestUtils.assertEquals(StatUtils.variance(model.getY().toArray()), model.calculateYVariance(), 0);
// Check residual variance consistency // Check residual variance consistency
double[] residuals = model.calculateResiduals().toArray(); double[] residuals = model.calculateResiduals().toArray();
RealMatrix X = model.X; RealMatrix X = model.getX();
TestUtils.assertEquals( TestUtils.assertEquals(
StatUtils.variance(model.calculateResiduals().toArray()) * (residuals.length - 1), StatUtils.variance(model.calculateResiduals().toArray()) * (residuals.length - 1),
model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20); model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20);
@ -482,22 +482,22 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
}; };
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.newSampleData(y, x); regression.newSampleData(y, x);
RealMatrix combinedX = regression.X.copy(); RealMatrix combinedX = regression.getX().copy();
RealVector combinedY = regression.Y.copy(); RealVector combinedY = regression.getY().copy();
regression.newXSampleData(x); regression.newXSampleData(x);
regression.newYSampleData(y); regression.newYSampleData(y);
Assert.assertEquals(combinedX, regression.X); Assert.assertEquals(combinedX, regression.getX());
Assert.assertEquals(combinedY, regression.Y); Assert.assertEquals(combinedY, regression.getY());
// No intercept // No intercept
regression.setNoIntercept(true); regression.setNoIntercept(true);
regression.newSampleData(y, x); regression.newSampleData(y, x);
combinedX = regression.X.copy(); combinedX = regression.getX().copy();
combinedY = regression.Y.copy(); combinedY = regression.getY().copy();
regression.newXSampleData(x); regression.newXSampleData(x);
regression.newYSampleData(y); regression.newYSampleData(y);
Assert.assertEquals(combinedX, regression.X); Assert.assertEquals(combinedX, regression.getX());
Assert.assertEquals(combinedY, regression.Y); Assert.assertEquals(combinedY, regression.getY());
} }
@Test(expected=IllegalArgumentException.class) @Test(expected=IllegalArgumentException.class)