Variable visibility: "protected" -> "private". Added "protected"
getter methods. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1296570 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
92a4d6b1a0
commit
12482617ae
|
@ -39,14 +39,28 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
MultipleLinearRegression {
|
MultipleLinearRegression {
|
||||||
|
|
||||||
/** X sample data. */
|
/** X sample data. */
|
||||||
protected RealMatrix X;
|
private RealMatrix xMatrix;
|
||||||
|
|
||||||
/** Y sample data. */
|
/** Y sample data. */
|
||||||
protected RealVector Y;
|
private RealVector yVector;
|
||||||
|
|
||||||
/** Whether or not the regression model includes an intercept. True means no intercept. */
|
/** Whether or not the regression model includes an intercept. True means no intercept. */
|
||||||
private boolean noIntercept = false;
|
private boolean noIntercept = false;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the X sample data.
|
||||||
|
*/
|
||||||
|
protected RealMatrix getX() {
|
||||||
|
return xMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the Y sample data.
|
||||||
|
*/
|
||||||
|
protected RealVector getY() {
|
||||||
|
return yVector;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return true if the model has no intercept term; false otherwise
|
* @return true if the model has no intercept term; false otherwise
|
||||||
* @since 2.2
|
* @since 2.2
|
||||||
|
@ -121,8 +135,8 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
x[i][j] = data[pointer++];
|
x[i][j] = data[pointer++];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.X = new Array2DRowRealMatrix(x);
|
this.xMatrix = new Array2DRowRealMatrix(x);
|
||||||
this.Y = new ArrayRealVector(y);
|
this.yVector = new ArrayRealVector(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -139,7 +153,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
if (y.length == 0) {
|
if (y.length == 0) {
|
||||||
throw new NoDataException();
|
throw new NoDataException();
|
||||||
}
|
}
|
||||||
this.Y = new ArrayRealVector(y);
|
this.yVector = new ArrayRealVector(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -175,7 +189,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
throw new NoDataException();
|
throw new NoDataException();
|
||||||
}
|
}
|
||||||
if (noIntercept) {
|
if (noIntercept) {
|
||||||
this.X = new Array2DRowRealMatrix(x, true);
|
this.xMatrix = new Array2DRowRealMatrix(x, true);
|
||||||
} else { // Augment design matrix with initial unitary column
|
} else { // Augment design matrix with initial unitary column
|
||||||
final int nVars = x[0].length;
|
final int nVars = x[0].length;
|
||||||
final double[][] xAug = new double[x.length][nVars + 1];
|
final double[][] xAug = new double[x.length][nVars + 1];
|
||||||
|
@ -186,7 +200,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
xAug[i][0] = 1.0d;
|
xAug[i][0] = 1.0d;
|
||||||
System.arraycopy(x[i], 0, xAug[i], 1, nVars);
|
System.arraycopy(x[i], 0, xAug[i], 1, nVars);
|
||||||
}
|
}
|
||||||
this.X = new Array2DRowRealMatrix(xAug, false);
|
this.xMatrix = new Array2DRowRealMatrix(xAug, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,7 +271,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
*/
|
*/
|
||||||
public double[] estimateResiduals() {
|
public double[] estimateResiduals() {
|
||||||
RealVector b = calculateBeta();
|
RealVector b = calculateBeta();
|
||||||
RealVector e = Y.subtract(X.operate(b));
|
RealVector e = yVector.subtract(xMatrix.operate(b));
|
||||||
return e.toArray();
|
return e.toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,7 +346,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
* @return Y variance
|
* @return Y variance
|
||||||
*/
|
*/
|
||||||
protected double calculateYVariance() {
|
protected double calculateYVariance() {
|
||||||
return new Variance().evaluate(Y.toArray());
|
return new Variance().evaluate(yVector.toArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -349,7 +363,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
protected double calculateErrorVariance() {
|
protected double calculateErrorVariance() {
|
||||||
RealVector residuals = calculateResiduals();
|
RealVector residuals = calculateResiduals();
|
||||||
return residuals.dotProduct(residuals) /
|
return residuals.dotProduct(residuals) /
|
||||||
(X.getRowDimension() - X.getColumnDimension());
|
(xMatrix.getRowDimension() - xMatrix.getColumnDimension());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -364,7 +378,7 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
*/
|
*/
|
||||||
protected RealVector calculateResiduals() {
|
protected RealVector calculateResiduals() {
|
||||||
RealVector b = calculateBeta();
|
RealVector b = calculateBeta();
|
||||||
return Y.subtract(X.operate(b));
|
return yVector.subtract(xMatrix.operate(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,10 +93,10 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
@Override
|
@Override
|
||||||
protected RealVector calculateBeta() {
|
protected RealVector calculateBeta() {
|
||||||
RealMatrix OI = getOmegaInverse();
|
RealMatrix OI = getOmegaInverse();
|
||||||
RealMatrix XT = X.transpose();
|
RealMatrix XT = getX().transpose();
|
||||||
RealMatrix XTOIX = XT.multiply(OI).multiply(X);
|
RealMatrix XTOIX = XT.multiply(OI).multiply(getX());
|
||||||
RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse();
|
RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse();
|
||||||
return inverse.multiply(XT).multiply(OI).operate(Y);
|
return inverse.multiply(XT).multiply(OI).operate(getY());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -109,7 +109,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
@Override
|
@Override
|
||||||
protected RealMatrix calculateBetaVariance() {
|
protected RealMatrix calculateBetaVariance() {
|
||||||
RealMatrix OI = getOmegaInverse();
|
RealMatrix OI = getOmegaInverse();
|
||||||
RealMatrix XTOIX = X.transpose().multiply(OI).multiply(X);
|
RealMatrix XTOIX = getX().transpose().multiply(OI).multiply(getX());
|
||||||
return new LUDecomposition(XTOIX).getSolver().getInverse();
|
return new LUDecomposition(XTOIX).getSolver().getInverse();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
protected double calculateErrorVariance() {
|
protected double calculateErrorVariance() {
|
||||||
RealVector residuals = calculateResiduals();
|
RealVector residuals = calculateResiduals();
|
||||||
double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
|
double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
|
||||||
return t / (X.getRowDimension() - X.getColumnDimension());
|
return t / (getX().getRowDimension() - getX().getColumnDimension());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
@Override
|
@Override
|
||||||
public void newSampleData(double[] data, int nobs, int nvars) {
|
public void newSampleData(double[] data, int nobs, int nvars) {
|
||||||
super.newSampleData(data, nobs, nvars);
|
super.newSampleData(data, nobs, nvars);
|
||||||
qr = new QRDecomposition(X);
|
qr = new QRDecomposition(getX());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -132,9 +132,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
*/
|
*/
|
||||||
public double calculateTotalSumOfSquares() {
|
public double calculateTotalSumOfSquares() {
|
||||||
if (isNoIntercept()) {
|
if (isNoIntercept()) {
|
||||||
return StatUtils.sumSq(Y.toArray());
|
return StatUtils.sumSq(getY().toArray());
|
||||||
} else {
|
} else {
|
||||||
return new SecondMoment().evaluate(Y.toArray());
|
return new SecondMoment().evaluate(getY().toArray());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,12 +180,12 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
* @since 2.2
|
* @since 2.2
|
||||||
*/
|
*/
|
||||||
public double calculateAdjustedRSquared() {
|
public double calculateAdjustedRSquared() {
|
||||||
final double n = X.getRowDimension();
|
final double n = getX().getRowDimension();
|
||||||
if (isNoIntercept()) {
|
if (isNoIntercept()) {
|
||||||
return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
|
return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
|
||||||
} else {
|
} else {
|
||||||
return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
|
return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
|
||||||
(calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
|
(calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,7 +197,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
@Override
|
@Override
|
||||||
protected void newXSampleData(double[][] x) {
|
protected void newXSampleData(double[][] x) {
|
||||||
super.newXSampleData(x);
|
super.newXSampleData(x);
|
||||||
qr = new QRDecomposition(X);
|
qr = new QRDecomposition(getX());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -207,7 +207,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected RealVector calculateBeta() {
|
protected RealVector calculateBeta() {
|
||||||
return qr.getSolver().solve(Y);
|
return qr.getSolver().solve(getY());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -223,7 +223,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected RealMatrix calculateBetaVariance() {
|
protected RealMatrix calculateBetaVariance() {
|
||||||
int p = X.getColumnDimension();
|
int p = getX().getColumnDimension();
|
||||||
RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
|
RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
|
||||||
RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
|
RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
|
||||||
return Rinv.multiply(Rinv.transpose());
|
return Rinv.multiply(Rinv.transpose());
|
||||||
|
|
|
@ -178,13 +178,13 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
||||||
double[][] covariance = MatrixUtils.createRealIdentityMatrix(4).scalarMultiply(2).getData();
|
double[][] covariance = MatrixUtils.createRealIdentityMatrix(4).scalarMultiply(2).getData();
|
||||||
GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression();
|
GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression();
|
||||||
regression.newSampleData(y, x, covariance);
|
regression.newSampleData(y, x, covariance);
|
||||||
RealMatrix combinedX = regression.X.copy();
|
RealMatrix combinedX = regression.getX().copy();
|
||||||
RealVector combinedY = regression.Y.copy();
|
RealVector combinedY = regression.getY().copy();
|
||||||
RealMatrix combinedCovInv = regression.getOmegaInverse();
|
RealMatrix combinedCovInv = regression.getOmegaInverse();
|
||||||
regression.newXSampleData(x);
|
regression.newXSampleData(x);
|
||||||
regression.newYSampleData(y);
|
regression.newYSampleData(y);
|
||||||
Assert.assertEquals(combinedX, regression.X);
|
Assert.assertEquals(combinedX, regression.getX());
|
||||||
Assert.assertEquals(combinedY, regression.Y);
|
Assert.assertEquals(combinedY, regression.getY());
|
||||||
Assert.assertEquals(combinedCovInv, regression.getOmegaInverse());
|
Assert.assertEquals(combinedCovInv, regression.getOmegaInverse());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -253,7 +253,7 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
||||||
OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
|
OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
|
||||||
ols.newSampleData(longley, nObs, 6);
|
ols.newSampleData(longley, nObs, 6);
|
||||||
final RealVector b = ols.calculateBeta().copy();
|
final RealVector b = ols.calculateBeta().copy();
|
||||||
final RealMatrix x = ols.X.copy();
|
final RealMatrix x = ols.getX().copy();
|
||||||
|
|
||||||
// Create a GLS model to reuse
|
// Create a GLS model to reuse
|
||||||
GLSMultipleLinearRegression gls = new GLSMultipleLinearRegression();
|
GLSMultipleLinearRegression gls = new GLSMultipleLinearRegression();
|
||||||
|
|
|
@ -86,22 +86,22 @@ public abstract class MultipleLinearRegressionAbstractTest {
|
||||||
};
|
};
|
||||||
AbstractMultipleLinearRegression regression = createRegression();
|
AbstractMultipleLinearRegression regression = createRegression();
|
||||||
regression.newSampleData(design, 4, 3);
|
regression.newSampleData(design, 4, 3);
|
||||||
RealMatrix flatX = regression.X.copy();
|
RealMatrix flatX = regression.getX().copy();
|
||||||
RealVector flatY = regression.Y.copy();
|
RealVector flatY = regression.getY().copy();
|
||||||
regression.newXSampleData(x);
|
regression.newXSampleData(x);
|
||||||
regression.newYSampleData(y);
|
regression.newYSampleData(y);
|
||||||
Assert.assertEquals(flatX, regression.X);
|
Assert.assertEquals(flatX, regression.getX());
|
||||||
Assert.assertEquals(flatY, regression.Y);
|
Assert.assertEquals(flatY, regression.getY());
|
||||||
|
|
||||||
// No intercept
|
// No intercept
|
||||||
regression.setNoIntercept(true);
|
regression.setNoIntercept(true);
|
||||||
regression.newSampleData(design, 4, 3);
|
regression.newSampleData(design, 4, 3);
|
||||||
flatX = regression.X.copy();
|
flatX = regression.getX().copy();
|
||||||
flatY = regression.Y.copy();
|
flatY = regression.getY().copy();
|
||||||
regression.newXSampleData(x);
|
regression.newXSampleData(x);
|
||||||
regression.newYSampleData(y);
|
regression.newYSampleData(y);
|
||||||
Assert.assertEquals(flatX, regression.X);
|
Assert.assertEquals(flatX, regression.getX());
|
||||||
Assert.assertEquals(flatY, regression.Y);
|
Assert.assertEquals(flatY, regression.getY());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
|
|
@ -434,7 +434,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
||||||
*/
|
*/
|
||||||
double[] residuals = model.estimateResiduals();
|
double[] residuals = model.estimateResiduals();
|
||||||
RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
|
RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
|
||||||
double[] hatResiduals = I.subtract(hat).operate(model.Y).toArray();
|
double[] hatResiduals = I.subtract(hat).operate(model.getY()).toArray();
|
||||||
TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
|
TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -457,11 +457,11 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
||||||
*/
|
*/
|
||||||
protected void checkVarianceConsistency(OLSMultipleLinearRegression model) throws Exception {
|
protected void checkVarianceConsistency(OLSMultipleLinearRegression model) throws Exception {
|
||||||
// Check Y variance consistency
|
// Check Y variance consistency
|
||||||
TestUtils.assertEquals(StatUtils.variance(model.Y.toArray()), model.calculateYVariance(), 0);
|
TestUtils.assertEquals(StatUtils.variance(model.getY().toArray()), model.calculateYVariance(), 0);
|
||||||
|
|
||||||
// Check residual variance consistency
|
// Check residual variance consistency
|
||||||
double[] residuals = model.calculateResiduals().toArray();
|
double[] residuals = model.calculateResiduals().toArray();
|
||||||
RealMatrix X = model.X;
|
RealMatrix X = model.getX();
|
||||||
TestUtils.assertEquals(
|
TestUtils.assertEquals(
|
||||||
StatUtils.variance(model.calculateResiduals().toArray()) * (residuals.length - 1),
|
StatUtils.variance(model.calculateResiduals().toArray()) * (residuals.length - 1),
|
||||||
model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20);
|
model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20);
|
||||||
|
@ -482,22 +482,22 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
||||||
};
|
};
|
||||||
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
|
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
|
||||||
regression.newSampleData(y, x);
|
regression.newSampleData(y, x);
|
||||||
RealMatrix combinedX = regression.X.copy();
|
RealMatrix combinedX = regression.getX().copy();
|
||||||
RealVector combinedY = regression.Y.copy();
|
RealVector combinedY = regression.getY().copy();
|
||||||
regression.newXSampleData(x);
|
regression.newXSampleData(x);
|
||||||
regression.newYSampleData(y);
|
regression.newYSampleData(y);
|
||||||
Assert.assertEquals(combinedX, regression.X);
|
Assert.assertEquals(combinedX, regression.getX());
|
||||||
Assert.assertEquals(combinedY, regression.Y);
|
Assert.assertEquals(combinedY, regression.getY());
|
||||||
|
|
||||||
// No intercept
|
// No intercept
|
||||||
regression.setNoIntercept(true);
|
regression.setNoIntercept(true);
|
||||||
regression.newSampleData(y, x);
|
regression.newSampleData(y, x);
|
||||||
combinedX = regression.X.copy();
|
combinedX = regression.getX().copy();
|
||||||
combinedY = regression.Y.copy();
|
combinedY = regression.getY().copy();
|
||||||
regression.newXSampleData(x);
|
regression.newXSampleData(x);
|
||||||
regression.newYSampleData(y);
|
regression.newYSampleData(y);
|
||||||
Assert.assertEquals(combinedX, regression.X);
|
Assert.assertEquals(combinedX, regression.getX());
|
||||||
Assert.assertEquals(combinedY, regression.Y);
|
Assert.assertEquals(combinedY, regression.getY());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
|
Loading…
Reference in New Issue