From 95ebe8294c66d2aeb91db99911c9de42ab3dc87b Mon Sep 17 00:00:00 2001 From: Phil Steitz Date: Mon, 23 Aug 2010 02:55:01 +0000 Subject: [PATCH] Added R-squared and adjusted R-squared statistics to OLSMultipleLinearRegression JIRA: MATH-386 git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@987983 13f79535-47bb-0310-9956-ffa450edef68 --- .../OLSMultipleLinearRegression.java | 50 +++++++++++++++++++ src/site/xdoc/changes.xml | 3 ++ src/test/R/multipleOLSRegressionTestCases | 34 ++++++++++--- .../OLSMultipleLinearRegressionTest.java | 10 +++- 4 files changed, 89 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java b/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java index 6f3912e0c..333f0dfaf 100644 --- a/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java +++ b/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java @@ -22,6 +22,7 @@ import org.apache.commons.math.linear.QRDecomposition; import org.apache.commons.math.linear.QRDecompositionImpl; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealVector; +import org.apache.commons.math.stat.descriptive.moment.SecondMoment; /** *

Implements ordinary least squares (OLS) to estimate the parameters of a @@ -121,6 +122,55 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio return Q.multiply(augI).multiply(Q.transpose()); } + /** + * Returns the sum of squared deviations of Y from its mean. + * + * @return total sum of squares + */ + public double calculateTotalSumOfSquares() { + return new SecondMoment().evaluate(Y.getData()); + } + + /** + * Returns the sum of square residuals. + * + * @return residual sum of squares + */ + public double calculateResidualSumOfSquares() { + final RealVector residuals = calculateResiduals(); + return residuals.dotProduct(residuals); + } + + /** + * Returns the R-Squared statistic, defined by the formula

+     * R2 = 1 - SSR / SSTO
+     * 
+ * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals} + * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares} + * + * @return R-square statistic + */ + public double calculateRSquared() { + return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares(); + } + + /** + * Returns the adjusted R-squared statistic, defined by the formula
+     * R2adj = 1 - [SSR (n - 1)] / [SSTO (n - p)]
+     * 
+ * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}, + * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number + * of observations and p is the number of parameters estimated (including the intercept). + * + * @return adjusted R-Squared statistic + */ + public double calculateAdjustedRSquared() { + final double n = X.getRowDimension(); + return 1 - (calculateResidualSumOfSquares() * (n - 1)) / + (calculateTotalSumOfSquares() * (n - X.getColumnDimension())); + // return 1 - ((1 - calculateRSquare()) * (n - 1) / (n - X.getColumnDimension() - 1)); + } + /** * Loads new x sample data, overriding any previous sample * diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index db34f4688..9f4888e8d 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -52,6 +52,9 @@ The type attribute can be add,update,fix,remove. If the output is not quite correct, check for invisible trailing spaces! --> + + Added R-squared and adjusted R-squared statistics to OLSMultipleLinearRegression. + Corrected the formula used for Y variance returned by calculateYVariance and associated methods in multiple regression classes (AbstractMultipleLinearRegression, diff --git a/src/test/R/multipleOLSRegressionTestCases b/src/test/R/multipleOLSRegressionTestCases index 1a2e4bf32..7b288d4d4 100644 --- a/src/test/R/multipleOLSRegressionTestCases +++ b/src/test/R/multipleOLSRegressionTestCases @@ -32,11 +32,13 @@ options(digits=16) # override number of digits displayed # function to verify OLS computations verifyRegression <- function(model, expectedBeta, expectedResiduals, - expectedErrors, expectedStdError, modelName) { + expectedErrors, expectedStdError, expectedRSquare, expecteAdjRSquare, modelName) { betaHat <- as.vector(coefficients(model)) residuals <- as.vector(residuals(model)) errors <- as.vector(as.matrix(coefficients(summary(model)))[,2]) stdError <- summary(model)$sigma + rSquare <- summary(model)$r.squared + adjRSquare <- summary(model)$adj.r.squared output <- c("Parameter test dataset = ", modelName) if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) { displayPadded(output, SUCCEEDED, WIDTH) @@ -61,6 +63,18 @@ verifyRegression <- function(model, expectedBeta, expectedResiduals, } else { displayPadded(output, FAILED, WIDTH) } + output <- c("RSquared test dataset = ", modelName) + if (assertEquals(expectedRSquare,rSquare,tol,"RSquared")) { + displayPadded(output, SUCCEEDED, WIDTH) + } else { + displayPadded(output, FAILED, WIDTH) + } + output <- c("Adjusted RSquared test dataset = ", modelName) + if (assertEquals(expecteAdjRSquare,adjRSquare,tol,"Adjusted RSquared")) { + displayPadded(output, SUCCEEDED, WIDTH) + } else { + displayPadded(output, FAILED, WIDTH) + } } #-------------------------------------------------------------------------- @@ -78,8 +92,10 @@ expectedBeta <- c(11.0,0.5,0.666666666666667,0.75,0.8,0.8333333333333333) expectedResiduals <- c(0,0,0,0,0,0) expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN) expectedStdError <- NaN -verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError, - "perfect fit") +expectedRSquare <- 1 +expectedAdjRSquare <- NaN +verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, + expectedStdError, expectedRSquare, expectedAdjRSquare, "perfect fit") # Longly # @@ -134,9 +150,11 @@ expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924, -39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727, -206.7578251937366) expectedStdError <- 304.8540735619638 +expectedRSquare <- 0.995479004577296 +expectedAdjRSquare <- 0.992465007628826 -verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError, - "Longly") +verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, +expectedStdError, expectedRSquare, expectedAdjRSquare, "Longly") # Swiss Fertility (R dataset named "swiss") @@ -225,9 +243,11 @@ expectedResiduals <- c(7.1044267859730512,1.6580347433531366, -0.4515205619767598,-10.2916870903837587,-15.7812984571900063) expectedStdError <- 7.73642194433223 +expectedRSquare <- 0.649789742860228 +expectedAdjRSquare <- 0.6164363850373927 -verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError, - "Swiss Fertility") +verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, +expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility") displayDashes(WIDTH) diff --git a/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java b/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java index bc43486e0..3b39bc477 100644 --- a/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java +++ b/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java @@ -109,7 +109,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs assertEquals(0.0, errors.subtract(referenceVariance).getNorm(), 5.0e-16 * referenceVariance.getNorm()); - + assertEquals(1, ((OLSMultipleLinearRegression) regression).calculateRSquared(), 1E-12); } @@ -186,6 +186,10 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs // Check regression standard error against R assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10); + // Check R-Square statistics against R + assertEquals(0.995479004577296, model.calculateRSquared(), 1E-12); + assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12); + checkVarianceConsistency(model); } @@ -294,6 +298,10 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs // Check regression standard error against R assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12); + // Check R-Square statistics against R + assertEquals(0.649789742860228, model.calculateRSquared(), 1E-12); + assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12); + checkVarianceConsistency(model); }