Added R-squared and adjusted R-squared statistics to

OLSMultipleLinearRegression
JIRA: MATH-386



git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@987983 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Phil Steitz 2010-08-23 02:55:01 +00:00
parent ad63c1629d
commit 95ebe8294c
4 changed files with 89 additions and 8 deletions

View File

@ -22,6 +22,7 @@ import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.stat.descriptive.moment.SecondMoment;
/**
* <p>Implements ordinary least squares (OLS) to estimate the parameters of a
@ -121,6 +122,55 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
return Q.multiply(augI).multiply(Q.transpose());
}
/**
* Returns the sum of squared deviations of Y from its mean.
*
* @return total sum of squares
*/
public double calculateTotalSumOfSquares() {
return new SecondMoment().evaluate(Y.getData());
}
/**
* Returns the sum of square residuals.
*
* @return residual sum of squares
*/
public double calculateResidualSumOfSquares() {
final RealVector residuals = calculateResiduals();
return residuals.dotProduct(residuals);
}
/**
* Returns the R-Squared statistic, defined by the formula <pre>
* R<sup>2</sup> = 1 - SSR / SSTO
* </pre>
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
* and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
*
* @return R-square statistic
*/
public double calculateRSquared() {
return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
}
/**
* Returns the adjusted R-squared statistic, defined by the formula <pre>
* R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
* </pre>
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
* SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
* of observations and p is the number of parameters estimated (including the intercept).
*
* @return adjusted R-Squared statistic
*/
public double calculateAdjustedRSquared() {
final double n = X.getRowDimension();
return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
(calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
// return 1 - ((1 - calculateRSquare()) * (n - 1) / (n - X.getColumnDimension() - 1));
}
/**
* Loads new x sample data, overriding any previous sample
*

View File

@ -52,6 +52,9 @@ The <action> type attribute can be add,update,fix,remove.
If the output is not quite correct, check for invisible trailing spaces!
-->
<release version="2.2" date="TBD" description="TBD">
<action dev="psteitz" type="fix" issue="MATH-386">
Added R-squared and adjusted R-squared statistics to OLSMultipleLinearRegression.
</action>
<action dev="psteitz" type="fix" issue="MATH-392" due-to="Mark Devaney">
Corrected the formula used for Y variance returned by calculateYVariance and associated
methods in multiple regression classes (AbstractMultipleLinearRegression,

View File

@ -32,11 +32,13 @@ options(digits=16) # override number of digits displayed
# function to verify OLS computations
verifyRegression <- function(model, expectedBeta, expectedResiduals,
expectedErrors, expectedStdError, modelName) {
expectedErrors, expectedStdError, expectedRSquare, expecteAdjRSquare, modelName) {
betaHat <- as.vector(coefficients(model))
residuals <- as.vector(residuals(model))
errors <- as.vector(as.matrix(coefficients(summary(model)))[,2])
stdError <- summary(model)$sigma
rSquare <- summary(model)$r.squared
adjRSquare <- summary(model)$adj.r.squared
output <- c("Parameter test dataset = ", modelName)
if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) {
displayPadded(output, SUCCEEDED, WIDTH)
@ -61,6 +63,18 @@ verifyRegression <- function(model, expectedBeta, expectedResiduals,
} else {
displayPadded(output, FAILED, WIDTH)
}
output <- c("RSquared test dataset = ", modelName)
if (assertEquals(expectedRSquare,rSquare,tol,"RSquared")) {
displayPadded(output, SUCCEEDED, WIDTH)
} else {
displayPadded(output, FAILED, WIDTH)
}
output <- c("Adjusted RSquared test dataset = ", modelName)
if (assertEquals(expecteAdjRSquare,adjRSquare,tol,"Adjusted RSquared")) {
displayPadded(output, SUCCEEDED, WIDTH)
} else {
displayPadded(output, FAILED, WIDTH)
}
}
#--------------------------------------------------------------------------
@ -78,8 +92,10 @@ expectedBeta <- c(11.0,0.5,0.666666666666667,0.75,0.8,0.8333333333333333)
expectedResiduals <- c(0,0,0,0,0,0)
expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN)
expectedStdError <- NaN
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
"perfect fit")
expectedRSquare <- 1
expectedAdjRSquare <- NaN
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "perfect fit")
# Longly
#
@ -134,9 +150,11 @@ expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924,
-39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727,
-206.7578251937366)
expectedStdError <- 304.8540735619638
expectedRSquare <- 0.995479004577296
expectedAdjRSquare <- 0.992465007628826
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
"Longly")
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "Longly")
# Swiss Fertility (R dataset named "swiss")
@ -225,9 +243,11 @@ expectedResiduals <- c(7.1044267859730512,1.6580347433531366,
-0.4515205619767598,-10.2916870903837587,-15.7812984571900063)
expectedStdError <- 7.73642194433223
expectedRSquare <- 0.649789742860228
expectedAdjRSquare <- 0.6164363850373927
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
"Swiss Fertility")
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility")
displayDashes(WIDTH)

View File

@ -109,7 +109,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
assertEquals(0.0,
errors.subtract(referenceVariance).getNorm(),
5.0e-16 * referenceVariance.getNorm());
assertEquals(1, ((OLSMultipleLinearRegression) regression).calculateRSquared(), 1E-12);
}
@ -186,6 +186,10 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
// Check regression standard error against R
assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10);
// Check R-Square statistics against R
assertEquals(0.995479004577296, model.calculateRSquared(), 1E-12);
assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12);
checkVarianceConsistency(model);
}
@ -294,6 +298,10 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
// Check regression standard error against R
assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12);
// Check R-Square statistics against R
assertEquals(0.649789742860228, model.calculateRSquared(), 1E-12);
assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12);
checkVarianceConsistency(model);
}