Added R-squared and adjusted R-squared statistics to
OLSMultipleLinearRegression. JIRA: MATH-386. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@987983 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
ad63c1629d
commit
95ebe8294c
|
@ -22,6 +22,7 @@ import org.apache.commons.math.linear.QRDecomposition;
|
|||
import org.apache.commons.math.linear.QRDecompositionImpl;
|
||||
import org.apache.commons.math.linear.RealMatrix;
|
||||
import org.apache.commons.math.linear.RealVector;
|
||||
import org.apache.commons.math.stat.descriptive.moment.SecondMoment;
|
||||
|
||||
/**
|
||||
* <p>Implements ordinary least squares (OLS) to estimate the parameters of a
|
||||
|
@ -121,6 +122,55 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
return Q.multiply(augI).multiply(Q.transpose());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the sum of squared deviations of Y from its mean.
|
||||
*
|
||||
* @return total sum of squares
|
||||
*/
|
||||
public double calculateTotalSumOfSquares() {
|
||||
return new SecondMoment().evaluate(Y.getData());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the sum of square residuals.
|
||||
*
|
||||
* @return residual sum of squares
|
||||
*/
|
||||
public double calculateResidualSumOfSquares() {
|
||||
final RealVector residuals = calculateResiduals();
|
||||
return residuals.dotProduct(residuals);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the R-Squared statistic, defined by the formula <pre>
|
||||
* R<sup>2</sup> = 1 - SSR / SSTO
|
||||
* </pre>
|
||||
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
|
||||
* and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
|
||||
*
|
||||
* @return R-square statistic
|
||||
*/
|
||||
public double calculateRSquared() {
|
||||
return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the adjusted R-squared statistic, defined by the formula <pre>
|
||||
* R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
|
||||
* </pre>
|
||||
* where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
|
||||
* SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
|
||||
* of observations and p is the number of parameters estimated (including the intercept).
|
||||
*
|
||||
* @return adjusted R-Squared statistic
|
||||
*/
|
||||
public double calculateAdjustedRSquared() {
|
||||
final double n = X.getRowDimension();
|
||||
return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
|
||||
(calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
|
||||
// return 1 - ((1 - calculateRSquare()) * (n - 1) / (n - X.getColumnDimension() - 1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads new x sample data, overriding any previous sample
|
||||
*
|
||||
|
|
|
@ -52,6 +52,9 @@ The <action> type attribute can be add,update,fix,remove.
|
|||
If the output is not quite correct, check for invisible trailing spaces!
|
||||
-->
|
||||
<release version="2.2" date="TBD" description="TBD">
|
||||
<action dev="psteitz" type="add" issue="MATH-386">
|
||||
Added R-squared and adjusted R-squared statistics to OLSMultipleLinearRegression.
|
||||
</action>
|
||||
<action dev="psteitz" type="fix" issue="MATH-392" due-to="Mark Devaney">
|
||||
Corrected the formula used for Y variance returned by calculateYVariance and associated
|
||||
methods in multiple regression classes (AbstractMultipleLinearRegression,
|
||||
|
|
|
@ -32,11 +32,13 @@ options(digits=16) # override number of digits displayed
|
|||
# function to verify OLS computations
|
||||
|
||||
verifyRegression <- function(model, expectedBeta, expectedResiduals,
|
||||
expectedErrors, expectedStdError, modelName) {
|
||||
expectedErrors, expectedStdError, expectedRSquare, expecteAdjRSquare, modelName) {
|
||||
betaHat <- as.vector(coefficients(model))
|
||||
residuals <- as.vector(residuals(model))
|
||||
errors <- as.vector(as.matrix(coefficients(summary(model)))[,2])
|
||||
stdError <- summary(model)$sigma
|
||||
rSquare <- summary(model)$r.squared
|
||||
adjRSquare <- summary(model)$adj.r.squared
|
||||
output <- c("Parameter test dataset = ", modelName)
|
||||
if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) {
|
||||
displayPadded(output, SUCCEEDED, WIDTH)
|
||||
|
@ -61,6 +63,18 @@ verifyRegression <- function(model, expectedBeta, expectedResiduals,
|
|||
} else {
|
||||
displayPadded(output, FAILED, WIDTH)
|
||||
}
|
||||
output <- c("RSquared test dataset = ", modelName)
|
||||
if (assertEquals(expectedRSquare,rSquare,tol,"RSquared")) {
|
||||
displayPadded(output, SUCCEEDED, WIDTH)
|
||||
} else {
|
||||
displayPadded(output, FAILED, WIDTH)
|
||||
}
|
||||
output <- c("Adjusted RSquared test dataset = ", modelName)
|
||||
if (assertEquals(expecteAdjRSquare,adjRSquare,tol,"Adjusted RSquared")) {
|
||||
displayPadded(output, SUCCEEDED, WIDTH)
|
||||
} else {
|
||||
displayPadded(output, FAILED, WIDTH)
|
||||
}
|
||||
}
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
|
@ -78,8 +92,10 @@ expectedBeta <- c(11.0,0.5,0.666666666666667,0.75,0.8,0.8333333333333333)
|
|||
expectedResiduals <- c(0,0,0,0,0,0)
|
||||
expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN)
|
||||
expectedStdError <- NaN
|
||||
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
|
||||
"perfect fit")
|
||||
expectedRSquare <- 1
|
||||
expectedAdjRSquare <- NaN
|
||||
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
|
||||
expectedStdError, expectedRSquare, expectedAdjRSquare, "perfect fit")
|
||||
|
||||
# Longley
|
||||
#
|
||||
|
@ -134,9 +150,11 @@ expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924,
|
|||
-39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727,
|
||||
-206.7578251937366)
|
||||
expectedStdError <- 304.8540735619638
|
||||
expectedRSquare <- 0.995479004577296
|
||||
expectedAdjRSquare <- 0.992465007628826
|
||||
|
||||
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
|
||||
"Longly")
|
||||
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
|
||||
expectedStdError, expectedRSquare, expectedAdjRSquare, "Longly")
|
||||
|
||||
# Swiss Fertility (R dataset named "swiss")
|
||||
|
||||
|
@ -225,9 +243,11 @@ expectedResiduals <- c(7.1044267859730512,1.6580347433531366,
|
|||
-0.4515205619767598,-10.2916870903837587,-15.7812984571900063)
|
||||
|
||||
expectedStdError <- 7.73642194433223
|
||||
expectedRSquare <- 0.649789742860228
|
||||
expectedAdjRSquare <- 0.6164363850373927
|
||||
|
||||
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
|
||||
"Swiss Fertility")
|
||||
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
|
||||
expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility")
|
||||
|
||||
|
||||
displayDashes(WIDTH)
|
||||
|
|
|
@ -109,7 +109,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
|||
assertEquals(0.0,
|
||||
errors.subtract(referenceVariance).getNorm(),
|
||||
5.0e-16 * referenceVariance.getNorm());
|
||||
|
||||
assertEquals(1, ((OLSMultipleLinearRegression) regression).calculateRSquared(), 1E-12);
|
||||
}
|
||||
|
||||
|
||||
|
@ -186,6 +186,10 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
|||
// Check regression standard error against R
|
||||
assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10);
|
||||
|
||||
// Check R-Square statistics against R
|
||||
assertEquals(0.995479004577296, model.calculateRSquared(), 1E-12);
|
||||
assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12);
|
||||
|
||||
checkVarianceConsistency(model);
|
||||
}
|
||||
|
||||
|
@ -294,6 +298,10 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
|
|||
// Check regression standard error against R
|
||||
assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12);
|
||||
|
||||
// Check R-Square statistics against R
|
||||
assertEquals(0.649789742860228, model.calculateRSquared(), 1E-12);
|
||||
assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12);
|
||||
|
||||
checkVarianceConsistency(model);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue