Made intercept/noIntercept configurable in multiple regression classes. JIRA: MATH-409.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@996404 13f79535-47bb-0310-9956-ffa450edef68
Phil Steitz 2010-09-13 02:01:42 +00:00
parent 820915c236
commit aba7e04d44
6 changed files with 257 additions and 58 deletions

View File

@@ -39,6 +39,23 @@ public abstract class AbstractMultipleLinearRegression implements
/** Y sample data. */
protected RealVector Y;
/** Whether or not the regression model includes an intercept. True means no intercept. */
private boolean noIntercept = false;
/**
* @return true if the model has no intercept term; false otherwise
*/
public boolean isNoIntercept() {
return noIntercept;
}
/**
* @param noIntercept true means the model is to be estimated without an intercept term
*/
public void setNoIntercept(boolean noIntercept) {
this.noIntercept = noIntercept;
}
/**
* <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
* </p>
@@ -55,7 +72,9 @@ public abstract class AbstractMultipleLinearRegression implements
* </pre>
* </p>
* <p>Note that there is no need to add an initial unitary column (column of 1's) when
* specifying a model including an intercept term.
* specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
* the X matrix will be created without an initial column of "1"s; otherwise this column will
* be added.
* </p>
* <p>Throws IllegalArgumentException if any of the following preconditions fail:
* <ul><li><code>data</code> cannot be null</li>
@@ -82,12 +101,15 @@ public abstract class AbstractMultipleLinearRegression implements
LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS);
}
double[] y = new double[nobs];
double[][] x = new double[nobs][nvars + 1];
final int cols = noIntercept ? nvars: nvars + 1;
double[][] x = new double[nobs][cols];
int pointer = 0;
for (int i = 0; i < nobs; i++) {
y[i] = data[pointer++];
x[i][0] = 1.0d;
for (int j = 1; j < nvars + 1; j++) {
if (!noIntercept) {
x[i][0] = 1.0d;
}
for (int j = noIntercept ? 0 : 1; j < cols; j++) {
x[i][j] = data[pointer++];
}
}
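For orientation (this sketch is not part of the diff; the data values are invented), the flat-array loading path above can be exercised through the public API as follows:

    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
    regression.setNoIntercept(true);           // X gets nvars columns; no column of 1s is added
    double[] data = new double[] {
        1.0, 1.1, 2.3,                         // y, x1, x2 for observation 1
        2.0, 0.9, 4.1,                         // observation 2
        3.0, 1.8, 5.9,                         // observation 3
        4.0, 2.2, 8.1};                        // observation 4
    regression.newSampleData(data, 4, 2);      // 4 observations, 2 regressors
    double[] beta = regression.estimateRegressionParameters();  // length 2: no intercept estimate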
@@ -145,18 +167,22 @@ public abstract class AbstractMultipleLinearRegression implements
throw MathRuntimeException.createIllegalArgumentException(
LocalizedFormats.NO_DATA);
}
final int nVars = x[0].length;
final double[][] xAug = new double[x.length][nVars + 1];
for (int i = 0; i < x.length; i++) {
if (x[i].length != nVars) {
throw MathRuntimeException.createIllegalArgumentException(
LocalizedFormats.DIFFERENT_ROWS_LENGTHS,
x[i].length, nVars);
if (noIntercept) {
this.X = new Array2DRowRealMatrix(x, true);
} else { // Augment design matrix with initial unitary column
final int nVars = x[0].length;
final double[][] xAug = new double[x.length][nVars + 1];
for (int i = 0; i < x.length; i++) {
if (x[i].length != nVars) {
throw MathRuntimeException.createIllegalArgumentException(
LocalizedFormats.DIFFERENT_ROWS_LENGTHS,
x[i].length, nVars);
}
xAug[i][0] = 1.0d;
System.arraycopy(x[i], 0, xAug[i], 1, nVars);
}
xAug[i][0] = 1.0d;
System.arraycopy(x[i], 0, xAug[i], 1, nVars);
this.X = new Array2DRowRealMatrix(xAug, false);
}
this.X = new Array2DRowRealMatrix(xAug, false);
}
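Similarly, a sketch (again outside the diff, with made-up values) of what the branch above means for the two-dimensional loading path exposed by OLSMultipleLinearRegression.newSampleData(double[], double[][]):

    double[] y = new double[] {1.0, 2.0, 3.0, 4.0};
    double[][] x = new double[][] {{2.0, 3.0}, {4.0, 5.0}, {6.0, 8.0}, {7.0, 2.0}};
    OLSMultipleLinearRegression withIntercept = new OLSMultipleLinearRegression();
    withIntercept.newSampleData(y, x);      // rows are augmented: [1.0, 2.0, 3.0], [1.0, 4.0, 5.0], ...
    OLSMultipleLinearRegression noIntercept = new OLSMultipleLinearRegression();
    noIntercept.setNoIntercept(true);
    noIntercept.newSampleData(y, x);        // rows are used as given: [2.0, 3.0], [4.0, 5.0], ...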
/**

View File

@@ -22,35 +22,32 @@ import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.stat.StatUtils;
import org.apache.commons.math.stat.descriptive.moment.SecondMoment;
/**
* <p>Implements ordinary least squares (OLS) to estimate the parameters of a
* multiple linear regression model.</p>
*
* <p>OLS assumes the covariance matrix of the error to be diagonal and with
* equal variance.</p>
* <p>
* u ~ N(0, &sigma;<sup>2</sup>I)
* </p>
*
* <p>The regression coefficients, b, satisfy the normal equations:
* <p>
* X<sup>T</sup> X b = X<sup>T</sup> y
* </p>
* <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
* <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
*
* <p>To solve the normal equations, this implementation uses QR decomposition
* of the X matrix. (See {@link QRDecompositionImpl} for details on the
* decomposition algorithm.)
* </p>
* <p>X<sup>T</sup>X b = X<sup>T</sup> y <br/>
* (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y <br/>
* R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
* R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
* (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
* R b = Q<sup>T</sup> y
* </p>
* Given Q and R, the last equation is solved by back-substitution.</p>
* of the <code>X</code> matrix. (See {@link QRDecompositionImpl} for details on the
* decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
* has rows corresponding to sample observations and columns corresponding to independent
* variables. When the model is estimated using an intercept term (i.e. when
* {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
* matrix includes an initial column identically equal to 1. We solve the normal equations
* as follows:
* <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
* (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
* R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
* R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
* (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
* R b = Q<sup>T</sup> y </code></pre></p>
*
* <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
*
* @version $Revision$ $Date$
* @since 2.0
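As an aside (not part of this commit), the back-substitution step described above can be reproduced directly with the linear algebra classes imported at the top of this file, plus Array2DRowRealMatrix and ArrayRealVector from the same package; the matrix values here are invented:

    RealMatrix X = new Array2DRowRealMatrix(new double[][] {
        {1.0, 2.0, 3.0},                       // unitary column first, then two regressors
        {1.0, 4.0, 5.0},
        {1.0, 6.0, 8.0},
        {1.0, 7.0, 2.0}});
    RealVector y = new ArrayRealVector(new double[] {1.0, 2.0, 3.0, 4.0});
    QRDecomposition qr = new QRDecompositionImpl(X);
    RealVector b = qr.getSolver().solve(y);    // least-squares solution of X b = y via R b = Q^T y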
@@ -122,12 +119,23 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
}
/**
* Returns the sum of squared deviations of Y from its mean.
* <p>Returns the sum of squared deviations of Y from its mean.</p>
*
* @return total sum of squares
* <p>If the model has no intercept term, <code>0</code> is used for the
* mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
*
* <p>The value returned by this method is the SSTO value used in
* the {@link #calculateRSquared() R-squared} computation.</p>
*
* @return SSTO - the total sum of squares
* @see #isNoIntercept()
*/
public double calculateTotalSumOfSquares() {
return new SecondMoment().evaluate(Y.getData());
if (isNoIntercept()) {
return StatUtils.sumSq(Y.getData());
} else {
return new SecondMoment().evaluate(Y.getData());
}
}
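A small numeric check (values invented) of the two SSTO definitions used above, using the classes imported at the top of the file:

    double[] yValues = new double[] {1.0, 2.0, 3.0};
    double ssToWithIntercept = new SecondMoment().evaluate(yValues);  // sum of (y_i - ybar)^2 = 2.0
    double ssToNoIntercept = StatUtils.sumSq(yValues);                // sum of y_i^2 = 14.0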
/**
@@ -154,24 +162,34 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
}
/**
* Returns the adjusted R-squared statistic, defined by the formula <pre>
* <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
* R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
* </pre>
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
* SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
* of observations and p is the number of parameters estimated (including the intercept).
* of observations and p is the number of parameters estimated (including the intercept).</p>
*
* <p>If the regression is estimated without an intercept term, what is returned is <pre>
* <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
* </pre></p>
*
* @return adjusted R-Squared statistic
* @see #isNoIntercept()
*/
public double calculateAdjustedRSquared() {
final double n = X.getRowDimension();
return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
(calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
if (isNoIntercept()) {
return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
} else {
return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
(calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
}
}
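A quick worked check of the two formulas (made-up values): with n = 5 observations, p = 2 estimated parameters, SSR = 2 and SSTO = 10, R-squared is 1 - 2/10 = 0.8; the intercept form gives 1 - (2 * 4) / (10 * 3) ≈ 0.733, while the no-intercept form gives 1 - (1 - 0.8) * (5 / 3) ≈ 0.667.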
/**
* {@inheritDoc}
* <p>This implementation computes and caches the QR decomposition of the X matrix once it is successfully loaded.</p>
* <p>This implementation computes and caches the QR decomposition of the X matrix
* once it is successfully loaded.</p>
*/
@Override
protected void newXSampleData(double[][] x) {
@@ -190,7 +208,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
}
/**
* <p>Calculates the variance on the beta by OLS.
* <p>Calculates the variance-covariance matrix of the regression parameters.
* </p>
* <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
* </p>
@@ -198,7 +216,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
* R included, where p = the length of the beta vector.</p>
*
* @return The beta variance
* @return The beta variance-covariance matrix
*/
@Override
protected RealMatrix calculateBetaVariance() {

View File

@@ -71,6 +71,11 @@ The <action> type attribute can be add,update,fix,remove.
</action>
</release>
<release version="2.2" date="TBD" description="TBD">
<action dev="psteitz" type="update" issue="MATH-409">
Made intercept / no intercept configurable in multiple regression classes. By default, regression
models are estimated with an intercept term. When the "noIntercept" property is set to
true, regression models are estimated without intercepts.
</action>
<action dev="luc" type="fix" issue="MATH-415">
Fixed lost cause in MathRuntimeException.createInternalError. Note that the message is still the default
message for internal errors asking to report a bug to commons-math JIRA tracker. In order to retrieve
@@ -84,10 +89,9 @@ The <action> type attribute can be add,update,fix,remove.
into the x[][] arrays to create a model with an intercept term; while newSampleData(double[], int, int)
created a model including an intercept term without requiring the unitary column. All methods have
been changed to eliminate the need for users to add unitary columns to specify regression models.
<!-- uncomment when MATH-409 is resolved (noIntercept option)
Users of OLSMultipleLinearRegression or GLSMultipleLinearRegression versions 2.0 or 2.1 should either
verify that their code does not use the first set of data loading methods above or set the noIntercept
property on estimated models to get the previous behavior. -->
property on estimated models to get the previous behavior.
</action>
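A sketch (invented data, not part of this change set) of the migration path described in the note above, for 2.0/2.1 code that built the unitary column by hand:

    double[] y = new double[] {1.0, 2.0, 3.0, 4.0};
    double[][] xWithOnes = new double[][] {    // caller-supplied column of 1s, 2.0/2.1 style
        {1.0, 3.0, 4.0},
        {1.0, 5.0, 6.0},
        {1.0, 7.0, 9.0},
        {1.0, 2.0, 8.0}};
    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
    regression.setNoIntercept(true);           // keep the hand-built column as the intercept
    regression.newSampleData(y, xWithOnes);    // previous behavior: no extra column is appended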
<action dev="luc" type="fix" issue="MATH-412" due-to="Bill Rossi">
Added the dfp library providing arbitrary precision floating point computation in the spirit of

View File

@@ -22,12 +22,12 @@
# source("<name-of-this-file>")
#
#------------------------------------------------------------------------------
tol <- 1E-8 # error tolerance for tests
tol <- 1E-8 # default error tolerance for tests
#------------------------------------------------------------------------------
# Function definitions
source("testFunctions") # utility test functions
options(digits=16) # override number of digits displayed
options(digits=16) # override number of digits displayed
# function to verify OLS computations
@@ -97,7 +97,7 @@ expectedAdjRSquare <- NaN
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "perfect fit")
# Longly
# Longley
#
# Data Source: J. Longley (1967) "An Appraisal of Least Squares Programs for the
# Electronic Computer from the Point of View of the User",
@@ -144,7 +144,7 @@ estimates <- matrix(c(-3482258.63459582,890420.383607373,
expectedBeta <- estimates[,1]
expectedErrors <- estimates[,2]
expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924,
expectedResiduals <- c(267.340029759711,-94.0139423988359,46.28716775752924,
-410.114621930906,309.7145907602313,-249.3112153297231,-164.0489563956039,
-13.18035686637081,14.30477260005235,455.394094551857,-17.26892711483297,
-39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727,
@@ -154,7 +154,32 @@ expectedRSquare <- 0.995479004577296
expectedAdjRSquare <- 0.992465007628826
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "Longly")
expectedStdError, expectedRSquare, expectedAdjRSquare, "Longley")
# Model with no intercept
model <- lm(y ~ 0 + x1 + x2 + x3 + x4 + x5 + x6)
estimates <- matrix(c(-52.99357013868291, 129.54486693117232,
0.07107319907358, 0.03016640003786,
-0.42346585566399, 0.41773654056612,
-0.57256866841929, 0.27899087467676,
-0.41420358884978, 0.32128496193363,
48.41786562001326, 17.68948737819961),
nrow = 6, ncol = 2, byrow = TRUE)
expectedBeta <- estimates[,1]
expectedErrors <- estimates[,2]
expectedResiduals <- c(279.90274927293092, -130.32465380836874, 90.73228661967445,
-401.31252201634948, -440.46768772620027, -543.54512853774793, 201.32111639536299,
215.90889365977932, 73.09368242049943, 913.21694494481869, 424.82484953610174,
-8.56475876776709, -361.32974610842876, 27.34560497213464, 151.28955976355002,
-492.49937355336846)
expectedStdError <- 475.1655079819517
expectedRSquare <- 0.9999670130706
expectedAdjRSquare <- 0.999947220913
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "Longley No Intercept")
# Swiss Fertility (R dataset named "swiss")
@@ -249,5 +274,36 @@ expectedAdjRSquare <- 0.6164363850373927
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility")
# model with no intercept
model <- lm(y ~ 0 + x1 + x2 + x3 + x4)
estimates <- matrix(c(0.52191832900513, 0.10470063765677,
2.36588087917963, 0.41684100584290,
-0.94770353802795, 0.43370143099691,
0.30851985863609, 0.07694953606522),
nrow = 4, ncol = 2, byrow = TRUE)
expectedBeta <- estimates[,1]
expectedErrors <- estimates[,2]
expectedResiduals <- c(44.138759883538249, 27.720705122356215, 35.873200836126799,
34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
-9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
-17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
-10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
-13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
-1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
-16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
-17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
-0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615)
expectedStdError <- 17.24710630547
expectedRSquare <- 0.946350722085
expectedAdjRSquare <- 0.9413600915813
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility No Intercept")
displayDashes(WIDTH)

View File

@@ -93,6 +93,16 @@ public abstract class MultipleLinearRegressionAbstractTest {
regression.newYSampleData(y);
assertEquals(flatX, regression.X);
assertEquals(flatY, regression.Y);
// No intercept
regression.setNoIntercept(true);
regression.newSampleData(design, 4, 3);
flatX = regression.X.copy();
flatY = regression.Y.copy();
regression.newXSampleData(x);
regression.newYSampleData(y);
assertEquals(flatX, regression.X);
assertEquals(flatY, regression.Y);
}
@Test(expected=IllegalArgumentException.class)

View File

@@ -137,9 +137,8 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
70551,116.9,554894,4007,2827,130081,1962
};
// Transform to Y and X required by interface
int nobs = 16;
int nvars = 6;
final int nobs = 16;
final int nvars = 6;
// Estimate the model
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
@@ -182,6 +181,40 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12);
checkVarianceConsistency(model);
// Estimate model without intercept
model.setNoIntercept(true);
model.newSampleData(design, nobs, nvars);
// Check expected beta values from R
betaHat = model.estimateRegressionParameters();
TestUtils.assertEquals(betaHat,
new double[]{-52.99357013868291, 0.07107319907358,
-0.42346585566399,-0.57256866841929,
-0.41420358884978, 48.41786562001326}, 1E-11);
// Check standard errors from R
errors = model.estimateRegressionParametersStandardErrors();
TestUtils.assertEquals(new double[] {129.54486693117232, 0.03016640003786,
0.41773654056612, 0.27899087467676, 0.32128496193363,
17.68948737819961}, errors, 1E-11);
// Check expected residuals from R
residuals = model.estimateResiduals();
TestUtils.assertEquals(residuals, new double[]{
279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948,
-440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932,
73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709,
-361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846},
1E-10);
// Check regression standard error against R
assertEquals(475.1655079819517, model.estimateRegressionStandardError(), 1E-10);
// Check R-Square statistics against R
assertEquals(0.9999670130706, model.calculateRSquared(), 1E-12);
assertEquals(0.999947220913, model.calculateAdjustedRSquared(), 1E-12);
}
/**
@@ -248,7 +281,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
model.newSampleData(design, nobs, nvars);
// Check expected beta values from R
final double[] betaHat = model.estimateRegressionParameters();
double[] betaHat = model.estimateRegressionParameters();
TestUtils.assertEquals(betaHat,
new double[]{91.05542390271397,
-0.22064551045715,
@@ -257,7 +290,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
0.12441843147162}, 1E-12);
// Check expected residuals from R
final double[] residuals = model.estimateResiduals();
double[] residuals = model.estimateResiduals();
TestUtils.assertEquals(residuals, new double[]{
7.1044267859730512,1.6580347433531366,
4.6944952770029644,8.4548022690166160,13.6547432343186212,
@@ -278,7 +311,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
1E-12);
// Check standard errors from R
final double[] errors = model.estimateRegressionParametersStandardErrors();
double[] errors = model.estimateRegressionParametersStandardErrors();
TestUtils.assertEquals(new double[] {6.94881329475087,
0.07360008972340,
0.27410957467466,
@@ -293,6 +326,48 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12);
checkVarianceConsistency(model);
// Estimate the model with no intercept
model = new OLSMultipleLinearRegression();
model.setNoIntercept(true);
model.newSampleData(design, nobs, nvars);
// Check expected beta values from R
betaHat = model.estimateRegressionParameters();
TestUtils.assertEquals(betaHat,
new double[]{0.52191832900513,
2.36588087917963,
-0.94770353802795,
0.30851985863609}, 1E-12);
// Check expected residuals from R
residuals = model.estimateResiduals();
TestUtils.assertEquals(residuals, new double[]{
44.138759883538249, 27.720705122356215, 35.873200836126799,
34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
-9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
-17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
-10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
-13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
-1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
-16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
-17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
-0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615},
1E-12);
// Check standard errors from R
errors = model.estimateRegressionParametersStandardErrors();
TestUtils.assertEquals(new double[] {0.10470063765677, 0.41684100584290,
0.43370143099691, 0.07694953606522}, errors, 1E-10);
// Check regression standard error against R
assertEquals(17.24710630547, model.estimateRegressionStandardError(), 1E-10);
// Check R-Square statistics against R
assertEquals(0.946350722085, model.calculateRSquared(), 1E-12);
assertEquals(0.9413600915813, model.calculateAdjustedRSquared(), 1E-12);
}
/**
@@ -415,6 +490,16 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
regression.newYSampleData(y);
assertEquals(combinedX, regression.X);
assertEquals(combinedY, regression.Y);
// No intercept
regression.setNoIntercept(true);
regression.newSampleData(y, x);
combinedX = regression.X.copy();
combinedY = regression.Y.copy();
regression.newXSampleData(x);
regression.newYSampleData(y);
assertEquals(combinedX, regression.X);
assertEquals(combinedY, regression.Y);
}
@Test(expected=IllegalArgumentException.class)