From aba7e04d4430e2f9786b97d19a22ccb536419687 Mon Sep 17 00:00:00 2001
From: Phil Steitz

     * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.</p>
     *
     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
-    * specifying a model including an intercept term.</p>
+    * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
+    * the X matrix will be created without an initial column of "1"s; otherwise this column will
+    * be added.</p>
     *
     * <p>Throws IllegalArgumentException if any of the following preconditions fail:
     * <ul><li><code>data</code> cannot be null</li>

 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
- * <p>OLS assumes the covariance matrix of the error to be diagonal and with
- * equal variance.</p>
- * <p>u ~ N(0, σ<sup>2</sup>I)</p>
- *
- * <p>The regression coefficients, b, satisfy the normal equations:</p>
- * <p>X<sup>T</sup> X b = X<sup>T</sup> y</p>
+ * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
+ * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
- * of the X matrix. (See {@link QRDecompositionImpl} for details on the
- * decomposition algorithm.)</p>
- * <p>
- * X<sup>T</sup>X b = X<sup>T</sup> y
- * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
- * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
- * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
- * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
- * R b = Q<sup>T</sup> y
- * </p>
+ * of the <code>X</code> matrix. (See {@link QRDecompositionImpl} for details on the
+ * decomposition algorithm.) The <code>X</code> matrix, also known as the design matrix,
+ * has rows corresponding to sample observations and columns corresponding to independent
+ * variables. When the model is estimated using an intercept term (i.e. when
+ * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
+ * matrix includes an initial column identically equal to 1. We solve the normal equations
+ * as follows:
+ * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
+ * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
+ * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
+ * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
+ * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
+ * R b = Q<sup>T</sup> y </code></pre></p>
+ *
+ * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
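To make the new intercept / no-intercept behavior concrete, here is a minimal usage sketch. It is illustrative only and not part of the patch: the data values are made up, and it relies only on methods that appear elsewhere in this change set (newSampleData(double[], int, int), setNoIntercept(boolean), estimateRegressionParameters()).

import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;

public class NoInterceptUsageSketch {
    public static void main(String[] args) {
        // Flat layout expected by newSampleData(double[], int, int):
        // rows are concatenated, each row listing y first, then the x values.
        double[] data = new double[] {
            1.0,  1.0,  2.0,
            2.0,  2.0,  3.0,
            3.0,  3.0,  5.0,
            5.0,  4.0,  8.0,
            8.0,  5.0, 13.0
        };
        int nobs = 5;   // observations
        int nvars = 2;  // independent variables per observation

        // Default behavior: an intercept term is estimated, so the unitary
        // column is added to the X matrix internally.
        OLSMultipleLinearRegression withIntercept = new OLSMultipleLinearRegression();
        withIntercept.newSampleData(data, nobs, nvars);
        double[] beta = withIntercept.estimateRegressionParameters();   // nvars + 1 parameters

        // Same data with noIntercept = true: no unitary column, nvars parameters.
        OLSMultipleLinearRegression noIntercept = new OLSMultipleLinearRegression();
        noIntercept.setNoIntercept(true);
        noIntercept.newSampleData(data, nobs, nvars);
        double[] betaNoIntercept = noIntercept.estimateRegressionParameters();

        System.out.println(beta.length + " parameters with intercept, "
                + betaNoIntercept.length + " without");
    }
}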
     * Returns the sum of squared deviations of Y from its mean.
     *
-    * @return total sum of squares
+    * <p>If the model has no intercept term, <code>0</code> is used for the
+    * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
+    *
+    * <p>The value returned by this method is the SSTO value used in
+    * the {@link #calculateRSquared() R-squared} computation.</p>
+    *
+    * @return SSTO - the total sum of squares
+    * @see #isNoIntercept()
     */
    public double calculateTotalSumOfSquares() {
-        return new SecondMoment().evaluate(Y.getData());
+        if (isNoIntercept()) {
+            return StatUtils.sumSq(Y.getData());
+        } else {
+            return new SecondMoment().evaluate(Y.getData());
+        }
    }
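A quick numeric illustration of the branch above (a sketch, not part of the patch): for Y = (1, 2, 3) the mean is 2, so the intercept-model SSTO is (-1)^2 + 0^2 + 1^2 = 2, while the no-intercept SSTO is the raw sum of squares 1 + 4 + 9 = 14. The same two library calls used by the patched method reproduce those values:

import org.apache.commons.math.stat.StatUtils;
import org.apache.commons.math.stat.descriptive.moment.SecondMoment;

public class TotalSumOfSquaresCheck {
    public static void main(String[] args) {
        double[] y = new double[] {1, 2, 3};
        // Intercept model: sum of squared deviations from the mean -> 2.0
        System.out.println(new SecondMoment().evaluate(y));
        // No-intercept model: mean treated as 0, raw sum of squares -> 14.0
        System.out.println(StatUtils.sumSq(y));
    }
}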
@@ -154,24 +162,34 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
     }
 
     /**
-     * Returns the adjusted R-squared statistic, defined by the formula
+     * <p>Returns the adjusted R-squared statistic, defined by the formula
      * <pre><code> R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)] </code></pre>
      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
      * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
-     * of observations and p is the number of parameters estimated (including the intercept).
+     * of observations and p is the number of parameters estimated (including the intercept).</p>
+     *
+     * <p>If the regression is estimated without an intercept term, what is returned is
+     * <pre><code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code></pre></p>
      *
      * @return adjusted R-Squared statistic
+     * @see #isNoIntercept()
      */
     public double calculateAdjustedRSquared() {
         final double n = X.getRowDimension();
-        return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
-            (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
+        if (isNoIntercept()) {
+            return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
+        } else {
+            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
+                (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
+        }
     }
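As a sanity check of the two formulas (again a sketch, not part of the patch), the Longley expectations that appear in the test changes below can be reproduced by hand with n = 16 observations, p = 7 parameters for the intercept model and p = 6 for the no-intercept model:

public class AdjustedRSquaredCheck {
    public static void main(String[] args) {
        double n = 16;                          // Longley observations

        // Intercept model: p = 7 (six regressors plus the intercept).
        // 1 - [SSR (n - 1)] / [SSTO (n - p)] == 1 - (1 - R^2) * (n - 1) / (n - p),
        // since R^2 = 1 - SSR / SSTO.
        double rSquared = 0.995479004577296;
        System.out.println(1 - (1 - rSquared) * (n - 1) / (n - 7));
        // ~0.992465007628826, the expectedAdjRSquare value used in the tests

        // No-intercept model: p = 6, using 1 - (1 - R^2) * (n / (n - p)).
        double rSquaredNoIntercept = 0.9999670130706;
        System.out.println(1 - (1 - rSquaredNoIntercept) * (n / (n - 6)));
        // ~0.999947220913, the no-intercept expectation in the tests
    }
}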
 
     /**
      * {@inheritDoc}
-     * <p>This implementation computes and caches the QR decomposition of the X matrix once it is successfully loaded.</p>
+     * <p>This implementation computes and caches the QR decomposition of the X matrix
+     * once it is successfully loaded.</p>
      */
     @Override
     protected void newXSampleData(double[][] x) {
@@ -190,7 +208,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
     }
 
     /**
-     * Calculates the variance on the beta by OLS.
+     * Calculates the variance-covariance matrix of the regression parameters.
      *
      * Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
      *
@@ -198,7 +216,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
      * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
      * R included, where p = the length of the beta vector.
      *
-     * @return The beta variance
+     * @return The beta variance-covariance matrix
      */
     @Override
     protected RealMatrix calculateBetaVariance() {
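The identity this javadoc relies on - (X<sup>T</sup>X)<sup>-1</sup> equals (R<sup>T</sup>R)<sup>-1</sup> when only the top p rows of R are kept - can be checked numerically. The sketch below is not part of the patch; the class name and data are made up, and it assumes the commons-math 2.x linear package (Array2DRowRealMatrix, QRDecompositionImpl, LUDecompositionImpl):

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;

public class BetaVarianceSketch {
    public static void main(String[] args) {
        // Small design matrix with the initial unitary column of an intercept model.
        RealMatrix x = new Array2DRowRealMatrix(new double[][] {
                {1, 2}, {1, 3}, {1, 5}, {1, 7}
        });
        int p = x.getColumnDimension();

        // Direct computation of (X^T X)^-1.
        RealMatrix direct = new LUDecompositionImpl(
                x.transpose().multiply(x)).getSolver().getInverse();

        // Same matrix via QR: keep only the top p rows of R, then invert R^T R.
        RealMatrix r = new QRDecompositionImpl(x).getR();
        RealMatrix rTop = r.getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix viaQr = new LUDecompositionImpl(
                rTop.transpose().multiply(rTop)).getSolver().getInverse();

        // The difference should be numerically negligible.
        System.out.println(direct.subtract(viaQr).getNorm());
    }
}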
diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml
index 454d9989c..819d95143 100644
--- a/src/site/xdoc/changes.xml
+++ b/src/site/xdoc/changes.xml
@@ -71,6 +71,11 @@ The <action> type attribute can be add,update,fix,remove.
+        Made intercept / no intercept configurable in multiple regression classes. By default, regression
+        models are estimated with an intercept term. When the "noIntercept" property is set to
+        true, regression models are estimated without intercepts.
+
         Fixed lost cause in MathRuntimeException.createInternalError. Note that
         the message is still the default message for internal errors asking to
         report a bug to commons-math JIRA tracker. In order to retrieve
@@ -84,10 +89,9 @@ The <action> type attribute can be add,update,fix,remove.
         into the x[][] arrays to create a model with an intercept term; while
         newSampleData(double[], int, int) created a model including an intercept term
         without requiring the unitary column. All methods have been changed to
-        eliminate the need for users to add unitary columns to specify regression models.
+        eliminate the need for users to add unitary columns to specify regression models.
+        Set the "noIntercept" property on estimated models to get the previous behavior.
         Added the dfp library providing arbitrary precision floating point computation in the spirit of
diff --git a/src/test/R/multipleOLSRegressionTestCases b/src/test/R/multipleOLSRegressionTestCases
index 7b288d4d4..1fe983de9 100644
--- a/src/test/R/multipleOLSRegressionTestCases
+++ b/src/test/R/multipleOLSRegressionTestCases
@@ -22,12 +22,12 @@
 # source(" ")
 #
 #------------------------------------------------------------------------------
-tol <- 1E-8                       # error tolerance for tests
+tol <- 1E-8                       # default error tolerance for tests
 #------------------------------------------------------------------------------
 # Function definitions
 source("testFunctions")           # utility test functions
-options(digits=16)                # override number of digits displayed
+options(digits=16)                # override number of digits displayed
 
 # function to verify OLS computations
@@ -97,7 +97,7 @@ expectedAdjRSquare <- NaN
 
 verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
 expectedStdError, expectedRSquare, expectedAdjRSquare, "perfect fit")
 
-# Longly
+# Longley
 #
 # Data Source: J. Longley (1967) "An Appraisal of Least Squares Programs for the
 #              Electronic Computer from the Point of View of the User",
@@ -144,7 +144,7 @@ estimates <- matrix(c(-3482258.63459582,890420.383607373,
 expectedBeta <- estimates[,1]
 expectedErrors <- estimates[,2]
 
-expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924,
+expectedResiduals <- c(267.340029759711,-94.0139423988359,46.28716775752924,
 -410.114621930906,309.7145907602313,-249.3112153297231,-164.0489563956039,
 -13.18035686637081,14.30477260005235,455.394094551857,-17.26892711483297,
 -39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727,
@@ -154,7 +154,32 @@ expectedRSquare <- 0.995479004577296
 expectedAdjRSquare <- 0.992465007628826
 
 verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
-expectedStdError, expectedRSquare, expectedAdjRSquare, "Longly")
+expectedStdError, expectedRSquare, expectedAdjRSquare, "Longley")
+
+# Model with no intercept
+model <- lm(y ~ 0 + x1 + x2 + x3 + x4 + x5 + x6)
+
+estimates <- matrix(c(-52.99357013868291, 129.54486693117232,
+                      0.07107319907358, 0.03016640003786,
+                      -0.42346585566399, 0.41773654056612,
+                      -0.57256866841929, 0.27899087467676,
+                      -0.41420358884978, 0.32128496193363,
+                      48.41786562001326, 17.68948737819961),
+                      nrow = 6, ncol = 2, byrow = TRUE)
+
+expectedBeta <- estimates[,1]
+expectedErrors <- estimates[,2]
+expectedResiduals <- c(279.90274927293092, -130.32465380836874, 90.73228661967445,
+ -401.31252201634948, -440.46768772620027, -543.54512853774793, 201.32111639536299,
+ 215.90889365977932, 73.09368242049943, 913.21694494481869, 424.82484953610174,
+ -8.56475876776709, -361.32974610842876, 27.34560497213464, 151.28955976355002,
+ -492.49937355336846)
+expectedStdError <- 475.1655079819517
+expectedRSquare <- 0.9999670130706
+expectedAdjRSquare <- 0.999947220913
+
+verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
+expectedStdError, expectedRSquare, expectedAdjRSquare, "Longley No Intercept")
 
 
 # Swiss Fertility (R dataset named "swiss")
@@ -249,5 +274,36 @@ expectedAdjRSquare <- 0.6164363850373927
 
 verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
 expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility")
 
+# model with no intercept
+model <- lm(y ~ 0 + x1 + x2 + x3 + x4)
+
+estimates <- matrix(c(0.52191832900513, 0.10470063765677,
+                      2.36588087917963, 0.41684100584290,
+                      -0.94770353802795, 0.43370143099691,
+                      0.30851985863609, 0.07694953606522),
+                      nrow = 4, ncol = 2, byrow = TRUE)
+
+expectedBeta <- estimates[,1]
+expectedErrors <- estimates[,2]
+
+expectedResiduals <- c(44.138759883538249, 27.720705122356215, 35.873200836126799,
+ 34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
+ 1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
+ -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
+ -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
+ -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
+ -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
+ -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
+ -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
+ -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
+ -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
+ 2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615)
+
+expectedStdError <- 17.24710630547
+expectedRSquare <- 0.946350722085
+expectedAdjRSquare <- 0.9413600915813
+
+verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
+expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility No Intercept")
 
 displayDashes(WIDTH)
diff --git a/src/test/java/org/apache/commons/math/stat/regression/MultipleLinearRegressionAbstractTest.java b/src/test/java/org/apache/commons/math/stat/regression/MultipleLinearRegressionAbstractTest.java
index c829d8e36..daffb6b33 100644
--- a/src/test/java/org/apache/commons/math/stat/regression/MultipleLinearRegressionAbstractTest.java
+++ b/src/test/java/org/apache/commons/math/stat/regression/MultipleLinearRegressionAbstractTest.java
@@ -93,6 +93,16 @@ public abstract class MultipleLinearRegressionAbstractTest {
         regression.newYSampleData(y);
         assertEquals(flatX, regression.X);
         assertEquals(flatY, regression.Y);
+
+        // No intercept
+        regression.setNoIntercept(true);
+        regression.newSampleData(design, 4, 3);
+        flatX = regression.X.copy();
+        flatY = regression.Y.copy();
+        regression.newXSampleData(x);
+        regression.newYSampleData(y);
+        assertEquals(flatX, regression.X);
+        assertEquals(flatY, regression.Y);
     }
 
     @Test(expected=IllegalArgumentException.class)
diff --git a/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java b/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
index 56c251d41..88739e2c3 100644
--- a/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
+++ b/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
@@ -137,9 +137,8 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
                           70551,116.9,554894,4007,2827,130081,1962 };
 
-        // Transform to Y and X required by interface
-        int nobs = 16;
-        int nvars = 6;
+        final int nobs = 16;
+        final int nvars = 6;
 
         // Estimate the model
         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
@@ -182,6 +181,40 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
         assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12);
 
         checkVarianceConsistency(model);
+
+        // Estimate model without intercept
+        model.setNoIntercept(true);
+        model.newSampleData(design, nobs, nvars);
+
+        // Check expected beta values from R
+        betaHat = model.estimateRegressionParameters();
+        TestUtils.assertEquals(betaHat,
+                new double[]{-52.99357013868291, 0.07107319907358,
+                        -0.42346585566399,-0.57256866841929,
+                        -0.41420358884978, 48.41786562001326}, 1E-11);
+
+        // Check standard errors from R
+        errors = model.estimateRegressionParametersStandardErrors();
+        TestUtils.assertEquals(new double[] {129.54486693117232, 0.03016640003786,
+                0.41773654056612, 0.27899087467676, 0.32128496193363,
+                17.68948737819961}, errors, 1E-11);
+
+        // Check expected residuals from R
+        residuals = model.estimateResiduals();
+        TestUtils.assertEquals(residuals, new double[]{
+                279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948,
+                -440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932,
+                73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709,
+                -361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846},
+                1E-10);
+
+        // Check regression standard error against R
+        assertEquals(475.1655079819517, model.estimateRegressionStandardError(), 1E-10);
+
+        // Check R-Square statistics against R
+        assertEquals(0.9999670130706, model.calculateRSquared(), 1E-12);
+        assertEquals(0.999947220913, model.calculateAdjustedRSquared(), 1E-12);
+
     }
 
     /**
@@ -248,7 +281,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
         model.newSampleData(design, nobs, nvars);
 
         // Check expected beta values from R
-        final double[] betaHat = model.estimateRegressionParameters();
+        double[] betaHat = model.estimateRegressionParameters();
         TestUtils.assertEquals(betaHat,
                 new double[]{91.05542390271397,
                         -0.22064551045715,
@@ -257,7 +290,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
                         0.12441843147162}, 1E-12);
 
         // Check expected residuals from R
-        final double[] residuals = model.estimateResiduals();
+        double[] residuals = model.estimateResiduals();
         TestUtils.assertEquals(residuals, new double[]{
                 7.1044267859730512,1.6580347433531366,
                 4.6944952770029644,8.4548022690166160,13.6547432343186212,
@@ -278,7 +311,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
                 1E-12);
 
         // Check standard errors from R
-        final double[] errors = model.estimateRegressionParametersStandardErrors();
+        double[] errors = model.estimateRegressionParametersStandardErrors();
         TestUtils.assertEquals(new double[] {6.94881329475087,
                 0.07360008972340,
                 0.27410957467466,
@@ -293,6 +326,48 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
         assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12);
 
         checkVarianceConsistency(model);
+
+        // Estimate the model with no intercept
+        model = new OLSMultipleLinearRegression();
+        model.setNoIntercept(true);
+        model.newSampleData(design, nobs, nvars);
+
+        // Check expected beta values from R
+        betaHat = model.estimateRegressionParameters();
+        TestUtils.assertEquals(betaHat,
+                new double[]{0.52191832900513,
+                        2.36588087917963,
+                        -0.94770353802795,
+                        0.30851985863609}, 1E-12);
+
+        // Check expected residuals from R
+        residuals = model.estimateResiduals();
+        TestUtils.assertEquals(residuals, new double[]{
+                44.138759883538249, 27.720705122356215, 35.873200836126799,
+                34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
+                1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
+                -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
+                -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
+                -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
+                -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
+                -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
+                -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
+                -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
+                -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
+                2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615},
+                1E-12);
+
+        // Check standard errors from R
+        errors = model.estimateRegressionParametersStandardErrors();
+        TestUtils.assertEquals(new double[] {0.10470063765677, 0.41684100584290,
+                0.43370143099691, 0.07694953606522}, errors, 1E-10);
+
+        // Check regression standard error against R
+        assertEquals(17.24710630547, model.estimateRegressionStandardError(), 1E-10);
+
+        // Check R-Square statistics against R
+        assertEquals(0.946350722085, model.calculateRSquared(), 1E-12);
+        assertEquals(0.9413600915813, model.calculateAdjustedRSquared(), 1E-12);
     }
 
     /**
@@ -415,6 +490,16 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
         regression.newYSampleData(y);
         assertEquals(combinedX, regression.X);
         assertEquals(combinedY, regression.Y);
+
+        // No intercept
+        regression.setNoIntercept(true);
+        regression.newSampleData(y, x);
+        combinedX = regression.X.copy();
+        combinedY = regression.Y.copy();
+        regression.newXSampleData(x);
+        regression.newYSampleData(y);
+        assertEquals(combinedX, regression.X);
+        assertEquals(combinedY, regression.Y);
     }
 
     @Test(expected=IllegalArgumentException.class)