Corrected Y variance formula and added error variance methods to return

what were previously reported as Y variances.
JIRA: MATH-392
Reported and patched by Mark Devaney


git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@987897 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Phil Steitz 2010-08-22 13:13:35 +00:00
parent a46c441cec
commit 30c9e8c111
7 changed files with 136 additions and 33 deletions

View File

@ -22,6 +22,7 @@ import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.Array2DRowRealMatrix; import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.RealVector; import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.ArrayRealVector; import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.stat.descriptive.moment.Variance;
/** /**
* Abstract base class for implementations of MultipleLinearRegression. * Abstract base class for implementations of MultipleLinearRegression.
@ -148,7 +149,8 @@ public abstract class AbstractMultipleLinearRegression implements
*/ */
public double[] estimateRegressionParametersStandardErrors() { public double[] estimateRegressionParametersStandardErrors() {
double[][] betaVariance = estimateRegressionParametersVariance(); double[][] betaVariance = estimateRegressionParametersVariance();
double sigma = calculateYVariance(); RealVector residuals = calculateResiduals();
double sigma = calculateErrorVariance();
int length = betaVariance[0].length; int length = betaVariance[0].length;
double[] result = new double[length]; double[] result = new double[length];
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
@ -164,6 +166,25 @@ public abstract class AbstractMultipleLinearRegression implements
return calculateYVariance(); return calculateYVariance();
} }
/**
* Estimates the variance of the error.
*
* @return estimate of the error variance
*/
public double estimateErrorVariance() {
return calculateErrorVariance();
}
/**
* Estimates the standard error of the regression.
*
* @return regression standard error
*/
public double estimateRegressionStandardError() {
return Math.sqrt(estimateErrorVariance());
}
/** /**
* Calculates the beta of multiple linear regression in matrix notation. * Calculates the beta of multiple linear regression in matrix notation.
* *
@ -179,12 +200,31 @@ public abstract class AbstractMultipleLinearRegression implements
*/ */
protected abstract RealMatrix calculateBetaVariance(); protected abstract RealMatrix calculateBetaVariance();
/** /**
* Calculates the Y variance of multiple linear regression. * Calculates the variance of the y values.
* *
* @return Y variance * @return Y variance
*/ */
protected abstract double calculateYVariance(); protected double calculateYVariance() {
return new Variance().evaluate(Y.getData());
}
/**
* <p>Calculates the variance of the error term.</p>
* Uses the formula <pre>
* var(u) = u &middot; u / (n - k)
* </pre>
* where n and k are the row and column dimensions of the design
* matrix X.
*
* @return error variance estimate
*/
protected double calculateErrorVariance() {
RealVector residuals = calculateResiduals();
return residuals.dotProduct(residuals) /
(X.getRowDimension() - X.getColumnDimension());
}
/** /**
* Calculates the residuals of multiple linear regression in matrix * Calculates the residuals of multiple linear regression in matrix

View File

@ -21,7 +21,6 @@ import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.Array2DRowRealMatrix; import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.RealVector; import org.apache.commons.math.linear.RealVector;
/** /**
* The GLS implementation of the multiple linear regression. * The GLS implementation of the multiple linear regression.
* *
@ -101,7 +100,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
} }
/** /**
* Calculates the variance on the beta by GLS. * Calculates the variance on the beta.
* <pre> * <pre>
* Var(b)=(X' Omega^-1 X)^-1 * Var(b)=(X' Omega^-1 X)^-1
* </pre> * </pre>
@ -114,18 +113,23 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
return new LUDecompositionImpl(XTOIX).getSolver().getInverse(); return new LUDecompositionImpl(XTOIX).getSolver().getInverse();
} }
/** /**
* Calculates the variance on the y by GLS. * Calculates the estimated variance of the error term using the formula
* <pre> * <pre>
* Var(y)=Tr(u' Omega^-1 u)/(n-k) * Var(u) = Tr(u' Omega^-1 u)/(n-k)
* </pre> * </pre>
* @return The Y variance * where n and k are the row and column dimensions of the design
* matrix X.
*
* @return error variance
*/ */
@Override @Override
protected double calculateYVariance() { protected double calculateErrorVariance() {
RealVector residuals = calculateResiduals(); RealVector residuals = calculateResiduals();
double t = residuals.dotProduct(getOmegaInverse().operate(residuals)); double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
return t / (X.getRowDimension() - X.getColumnDimension()); return t / (X.getRowDimension() - X.getColumnDimension());
} }
} }

View File

@ -49,7 +49,7 @@ import org.apache.commons.math.linear.RealVector;
* (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/> * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
* R b = Q<sup>T</sup> y * R b = Q<sup>T</sup> y
* </p> * </p>
* Given Q and R, the last equation is solved by back-subsitution.</p> * Given Q and R, the last equation is solved by back-substitution.</p>
* *
* @version $Revision$ $Date$ * @version $Revision$ $Date$
* @since 2.0 * @since 2.0
@ -161,19 +161,4 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
return Rinv.multiply(Rinv.transpose()); return Rinv.multiply(Rinv.transpose());
} }
/**
* <p>Calculates the variance on the Y by OLS.
* </p>
* <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k)
* </p>
* @return The Y variance
*/
@Override
protected double calculateYVariance() {
RealVector residuals = calculateResiduals();
return residuals.dotProduct(residuals) /
(X.getRowDimension() - X.getColumnDimension());
}
} }

View File

@ -52,6 +52,13 @@ The <action> type attribute can be add,update,fix,remove.
If the output is not quite correct, check for invisible trailing spaces! If the output is not quite correct, check for invisible trailing spaces!
--> -->
<release version="2.2" date="TBD" description="TBD"> <release version="2.2" date="TBD" description="TBD">
<action dev="psteitz" type="fix" issue="MATH-392" due-to="Mark Devaney">
Corrected the formula used for Y variance returned by calculateYVariance and associated
methods in multiple regression classes (AbstractMultipleLinearRegression,
OLSMultipleLinearRegression, GLSMultipleLinearRegression). These methods previously returned
estimates of the variance in the model error term. New "calulateErrorVariance" methods have
been added to compute what was previously returned by calculateYVariance.
</action>
<action dev="dimpbx" type="fix" issue="MATH-406"> <action dev="dimpbx" type="fix" issue="MATH-406">
Bug fixed in Levenberg-Marquardt (handling of weights). Bug fixed in Levenberg-Marquardt (handling of weights).
</action> </action>

View File

@ -32,10 +32,11 @@ options(digits=16) # override number of digits displayed
# function to verify OLS computations # function to verify OLS computations
verifyRegression <- function(model, expectedBeta, expectedResiduals, verifyRegression <- function(model, expectedBeta, expectedResiduals,
expectedErrors, modelName) { expectedErrors, expectedStdError, modelName) {
betaHat <- as.vector(coefficients(model)) betaHat <- as.vector(coefficients(model))
residuals <- as.vector(residuals(model)) residuals <- as.vector(residuals(model))
errors <- as.vector(as.matrix(coefficients(summary(model)))[,2]) errors <- as.vector(as.matrix(coefficients(summary(model)))[,2])
stdError <- summary(model)$sigma
output <- c("Parameter test dataset = ", modelName) output <- c("Parameter test dataset = ", modelName)
if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) { if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) {
displayPadded(output, SUCCEEDED, WIDTH) displayPadded(output, SUCCEEDED, WIDTH)
@ -54,6 +55,12 @@ verifyRegression <- function(model, expectedBeta, expectedResiduals,
} else { } else {
displayPadded(output, FAILED, WIDTH) displayPadded(output, FAILED, WIDTH)
} }
output <- c("Standard Error test dataset = ", modelName)
if (assertEquals(expectedStdError,stdError,tol,"Regression Standard Error")) {
displayPadded(output, SUCCEEDED, WIDTH)
} else {
displayPadded(output, FAILED, WIDTH)
}
} }
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------
@ -70,7 +77,8 @@ model <- lm(y ~ x1 + x2 + x3 + x4 + x5)
expectedBeta <- c(11.0,0.5,0.666666666666667,0.75,0.8,0.8333333333333333) expectedBeta <- c(11.0,0.5,0.666666666666667,0.75,0.8,0.8333333333333333)
expectedResiduals <- c(0,0,0,0,0,0) expectedResiduals <- c(0,0,0,0,0,0)
expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN) expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN)
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError <- NaN
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
"perfect fit") "perfect fit")
# Longly # Longly
@ -125,8 +133,9 @@ expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924,
-13.18035686637081,14.30477260005235,455.394094551857,-17.26892711483297, -13.18035686637081,14.30477260005235,455.394094551857,-17.26892711483297,
-39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727, -39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727,
-206.7578251937366) -206.7578251937366)
expectedStdError <- 304.8540735619638
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
"Longly") "Longly")
# Swiss Fertility (R dataset named "swiss") # Swiss Fertility (R dataset named "swiss")
@ -215,7 +224,9 @@ expectedResiduals <- c(7.1044267859730512,1.6580347433531366,
15.0147574652763112,4.8625103516321015,-7.1597256413907706, 15.0147574652763112,4.8625103516321015,-7.1597256413907706,
-0.4515205619767598,-10.2916870903837587,-15.7812984571900063) -0.4515205619767598,-10.2916870903837587,-15.7812984571900063)
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError <- 7.73642194433223
verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
"Swiss Fertility") "Swiss Fertility")

View File

@ -18,6 +18,8 @@ package org.apache.commons.math.stat.regression;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.commons.math.TestUtils;
import org.apache.commons.math.stat.StatUtils;
public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest { public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
@ -121,4 +123,16 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
return y.length; return y.length;
} }
/**
* test calculateYVariance
*/
@Test
public void testYVariance() {
// assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
GLSMultipleLinearRegression model = new GLSMultipleLinearRegression();
model.newSampleData(y, x, omega);
TestUtils.assertEquals(model.calculateYVariance(), 3.5, 0);
}
} }

View File

@ -24,6 +24,7 @@ import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.MatrixVisitorException; import org.apache.commons.math.linear.MatrixVisitorException;
import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.Array2DRowRealMatrix; import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.stat.StatUtils;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -82,7 +83,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
} }
@Test @Test
public void testPerfectFit() { public void testPerfectFit() throws Exception {
double[] betaHat = regression.estimateRegressionParameters(); double[] betaHat = regression.estimateRegressionParameters();
TestUtils.assertEquals(betaHat, TestUtils.assertEquals(betaHat,
new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 }, new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
@ -108,6 +109,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
assertEquals(0.0, assertEquals(0.0,
errors.subtract(referenceVariance).getNorm(), errors.subtract(referenceVariance).getNorm(),
5.0e-16 * referenceVariance.getNorm()); 5.0e-16 * referenceVariance.getNorm());
} }
@ -122,7 +124,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
* http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
*/ */
@Test @Test
public void testLongly() { public void testLongly() throws Exception {
// Y values are first, then independent vars // Y values are first, then independent vars
// Each row is one observation // Each row is one observation
double[] design = new double[] { double[] design = new double[] {
@ -180,6 +182,11 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
0.214274163161675, 0.214274163161675,
0.226073200069370, 0.226073200069370,
455.478499142212}, errors, 1E-6); 455.478499142212}, errors, 1E-6);
// Check regression standard error against R
assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10);
checkVarianceConsistency(model);
} }
/** /**
@ -187,7 +194,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
* Data Source: R datasets package * Data Source: R datasets package
*/ */
@Test @Test
public void testSwissFertility() { public void testSwissFertility() throws Exception {
double[] design = new double[] { double[] design = new double[] {
80.2,17.0,15,12,9.96, 80.2,17.0,15,12,9.96,
83.1,45.1,6,9,84.84, 83.1,45.1,6,9,84.84,
@ -283,6 +290,11 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
0.27410957467466, 0.27410957467466,
0.19454551679325, 0.19454551679325,
0.03726654773803}, errors, 1E-10); 0.03726654773803}, errors, 1E-10);
// Check regression standard error against R
assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12);
checkVarianceConsistency(model);
} }
/** /**
@ -354,4 +366,34 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
double[] hatResiduals = I.subtract(hat).operate(model.Y).getData(); double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
TestUtils.assertEquals(residuals, hatResiduals, 10e-12); TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
} }
/**
* test calculateYVariance
*/
@Test
public void testYVariance() {
// assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
model.newSampleData(y, x);
TestUtils.assertEquals(model.calculateYVariance(), 3.5, 0);
}
/**
* Verifies that calculateYVariance and calculateResidualVariance return consistent
* values with direct variance computation from Y, residuals, respectively.
*/
protected void checkVarianceConsistency(OLSMultipleLinearRegression model) throws Exception {
// Check Y variance consistency
TestUtils.assertEquals(StatUtils.variance(model.Y.getData()), model.calculateYVariance(), 0);
// Check residual variance consistency
double[] residuals = model.calculateResiduals().getData();
RealMatrix X = model.X;
TestUtils.assertEquals(
StatUtils.variance(model.calculateResiduals().getData()) * (residuals.length - 1),
model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20);
}
} }