Corrected Y variance formula and added error variance methods to return
what were previously reported as Y variances.

JIRA: MATH-392
Reported and patched by Mark Devaney

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@987897 13f79535-47bb-0310-9956-ffa450edef68
parent a46c441cec
commit 30c9e8c111
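The distinction the commit message draws, written out for reference (standard definitions, not text from the patch): with n observations, a design matrix X with k columns, fitted coefficients \hat{\beta} and residuals u = y - X\hat{\beta},

    \operatorname{Var}(y) = \frac{1}{n-1} \sum_{i=1}^{n} (y_i - \bar{y})^2        (sample variance of the observations)
    \hat{\sigma}_u^2 = \frac{u^{\top} u}{n - k}                                   (estimated variance of the error term)

Before this fix, calculateYVariance returned the second quantity; after it, calculateYVariance returns the first, and the new calculateErrorVariance / estimateErrorVariance methods return the second.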
AbstractMultipleLinearRegression.java

@@ -22,6 +22,7 @@ import org.apache.commons.math.linear.RealMatrix;
 import org.apache.commons.math.linear.Array2DRowRealMatrix;
 import org.apache.commons.math.linear.RealVector;
 import org.apache.commons.math.linear.ArrayRealVector;
+import org.apache.commons.math.stat.descriptive.moment.Variance;
 
 /**
  * Abstract base class for implementations of MultipleLinearRegression.
@@ -148,7 +149,8 @@ public abstract class AbstractMultipleLinearRegression implements
      */
     public double[] estimateRegressionParametersStandardErrors() {
         double[][] betaVariance = estimateRegressionParametersVariance();
-        double sigma = calculateYVariance();
+        RealVector residuals = calculateResiduals();
+        double sigma = calculateErrorVariance();
         int length = betaVariance[0].length;
         double[] result = new double[length];
         for (int i = 0; i < length; i++) {
@@ -164,6 +166,25 @@ public abstract class AbstractMultipleLinearRegression implements
         return calculateYVariance();
     }
 
+    /**
+     * Estimates the variance of the error.
+     *
+     * @return estimate of the error variance
+     */
+    public double estimateErrorVariance() {
+        return calculateErrorVariance();
+    }
+
+    /**
+     * Estimates the standard error of the regression.
+     *
+     * @return regression standard error
+     */
+    public double estimateRegressionStandardError() {
+        return Math.sqrt(estimateErrorVariance());
+    }
+
     /**
      * Calculates the beta of multiple linear regression in matrix notation.
      *
@@ -179,12 +200,31 @@ public abstract class AbstractMultipleLinearRegression implements
      */
     protected abstract RealMatrix calculateBetaVariance();
 
 
     /**
-     * Calculates the Y variance of multiple linear regression.
+     * Calculates the variance of the y values.
      *
      * @return Y variance
      */
-    protected abstract double calculateYVariance();
+    protected double calculateYVariance() {
+        return new Variance().evaluate(Y.getData());
+    }
+
+    /**
+     * <p>Calculates the variance of the error term.</p>
+     * Uses the formula <pre>
+     * var(u) = u · u / (n - k)
+     * </pre>
+     * where n and k are the row and column dimensions of the design
+     * matrix X.
+     *
+     * @return error variance estimate
+     */
+    protected double calculateErrorVariance() {
+        RealVector residuals = calculateResiduals();
+        return residuals.dotProduct(residuals) /
+               (X.getRowDimension() - X.getColumnDimension());
+    }
+
     /**
      * Calculates the residuals of multiple linear regression in matrix
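To make the new public API concrete, here is a minimal usage sketch. It is not part of the commit: the class name ErrorVarianceSketch and the data values are made up for illustration, and it assumes only calls that appear elsewhere in this diff (OLSMultipleLinearRegression.newSampleData(double[], double[][]), StatUtils.variance) plus the two methods added above.

import org.apache.commons.math.stat.StatUtils;
import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;

public class ErrorVarianceSketch {
    public static void main(String[] args) {
        // Illustrative data only: n = 6 observations, k = 2 regressors.
        double[] y = new double[] {1.1, 1.9, 3.2, 3.8, 5.1, 5.9};
        double[][] x = new double[][] {
            {1.0, 1.0}, {2.0, 1.0}, {3.0, 2.0}, {4.0, 2.0}, {5.0, 3.0}, {6.0, 3.0}
        };

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(y, x);

        // Sample variance of the observed y values (denominator n - 1).
        // After this fix, calculateYVariance agrees with this quantity.
        double yVariance = StatUtils.variance(y);

        // Estimated variance of the model error term (denominator n - k):
        // what calculateYVariance used to return, now exposed as estimateErrorVariance().
        double errorVariance = regression.estimateErrorVariance();

        // Regression standard error = square root of the error variance.
        double stdError = regression.estimateRegressionStandardError();

        System.out.println("var(y) = " + yVariance);
        System.out.println("var(u) = " + errorVariance + ", regression std error = " + stdError);
    }
}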
GLSMultipleLinearRegression.java

@@ -21,7 +21,6 @@ import org.apache.commons.math.linear.RealMatrix;
 import org.apache.commons.math.linear.Array2DRowRealMatrix;
 import org.apache.commons.math.linear.RealVector;
-
 
 /**
  * The GLS implementation of the multiple linear regression.
  *
@@ -101,7 +100,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
     }
 
     /**
-     * Calculates the variance on the beta by GLS.
+     * Calculates the variance on the beta.
      * <pre>
      *  Var(b)=(X' Omega^-1 X)^-1
      * </pre>
@@ -114,18 +113,23 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
         return new LUDecompositionImpl(XTOIX).getSolver().getInverse();
     }
 
-
     /**
-     * Calculates the variance on the y by GLS.
+     * Calculates the estimated variance of the error term using the formula
      * <pre>
-     *  Var(y)=Tr(u' Omega^-1 u)/(n-k)
+     *  Var(u) = Tr(u' Omega^-1 u)/(n-k)
      * </pre>
-     * @return The Y variance
+     * where n and k are the row and column dimensions of the design
+     * matrix X.
+     *
+     * @return error variance
      */
     @Override
-    protected double calculateYVariance() {
+    protected double calculateErrorVariance() {
         RealVector residuals = calculateResiduals();
         double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
         return t / (X.getRowDimension() - X.getColumnDimension());
     }
 
 }
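For the GLS hunks above, the quantities involved written out (standard GLS results consistent with the javadoc shown, added only for readability, following the javadoc's convention that Omega is the error covariance matrix):

    \hat{\beta}_{GLS} = (X^{\top} \Omega^{-1} X)^{-1} X^{\top} \Omega^{-1} y
    \operatorname{Var}(\hat{\beta}_{GLS}) = (X^{\top} \Omega^{-1} X)^{-1}
    \hat{\sigma}_u^2 = \frac{\hat{u}^{\top} \Omega^{-1} \hat{u}}{n - k}, \qquad \hat{u} = y - X \hat{\beta}_{GLS}

where n and k are the row and column dimensions of X, as in the renamed calculateErrorVariance.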
OLSMultipleLinearRegression.java

@@ -49,7 +49,7 @@ import org.apache.commons.math.linear.RealVector;
  * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
  * R b = Q<sup>T</sup> y
  * </p>
- * Given Q and R, the last equation is solved by back-subsitution.</p>
+ * Given Q and R, the last equation is solved by back-substitution.</p>
  *
  * @version $Revision$ $Date$
  * @since 2.0
@@ -161,19 +161,4 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
         return Rinv.multiply(Rinv.transpose());
     }
 
-
-    /**
-     * <p>Calculates the variance on the Y by OLS.
-     * </p>
-     * <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k)
-     * </p>
-     * @return The Y variance
-     */
-    @Override
-    protected double calculateYVariance() {
-        RealVector residuals = calculateResiduals();
-        return residuals.dotProduct(residuals) /
-               (X.getRowDimension() - X.getColumnDimension());
-    }
-
 }
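The class javadoc corrected above solves the least-squares problem through the QR decomposition of X; spelled out (standard linear algebra, added only for readability):

    X = QR, \quad Q^{\top} Q = I, \quad R \text{ upper triangular}
    X^{\top} X b = X^{\top} y \;\Longleftrightarrow\; R^{\top} R b = R^{\top} Q^{\top} y \;\Longleftrightarrow\; R b = Q^{\top} y

so b is obtained from R b = Q^{\top} y by back-substitution, as the corrected sentence states.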
changes.xml

@@ -52,6 +52,13 @@ The <action> type attribute can be add,update,fix,remove.
   If the output is not quite correct, check for invisible trailing spaces!
  -->
     <release version="2.2" date="TBD" description="TBD">
+      <action dev="psteitz" type="fix" issue="MATH-392" due-to="Mark Devaney">
+        Corrected the formula used for Y variance returned by calculateYVariance and associated
+        methods in multiple regression classes (AbstractMultipleLinearRegression,
+        OLSMultipleLinearRegression, GLSMultipleLinearRegression). These methods previously returned
+        estimates of the variance in the model error term. New "calculateErrorVariance" methods have
+        been added to compute what was previously returned by calculateYVariance.
+      </action>
       <action dev="dimpbx" type="fix" issue="MATH-406">
         Bug fixed in Levenberg-Marquardt (handling of weights).
       </action>
@@ -32,10 +32,11 @@ options(digits=16)    # override number of digits displayed
 
 # function to verify OLS computations
 
 verifyRegression <- function(model, expectedBeta, expectedResiduals,
-  expectedErrors, modelName) {
+  expectedErrors, expectedStdError, modelName) {
     betaHat <- as.vector(coefficients(model))
     residuals <- as.vector(residuals(model))
     errors <- as.vector(as.matrix(coefficients(summary(model)))[,2])
+    stdError <- summary(model)$sigma
     output <- c("Parameter test dataset = ", modelName)
     if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) {
        displayPadded(output, SUCCEEDED, WIDTH)
@@ -53,7 +54,13 @@ verifyRegression <- function(model, expectedBeta, expectedResiduals,
        displayPadded(output, SUCCEEDED, WIDTH)
     } else {
        displayPadded(output, FAILED, WIDTH)
     }
+    output <- c("Standard Error test dataset = ", modelName)
+    if (assertEquals(expectedStdError,stdError,tol,"Regression Standard Error")) {
+       displayPadded(output, SUCCEEDED, WIDTH)
+    } else {
+       displayPadded(output, FAILED, WIDTH)
+    }
 }
 
 #--------------------------------------------------------------------------
@@ -70,7 +77,8 @@ model <- lm(y ~ x1 + x2 + x3 + x4 + x5)
 expectedBeta <- c(11.0,0.5,0.666666666666667,0.75,0.8,0.8333333333333333)
 expectedResiduals <- c(0,0,0,0,0,0)
 expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN)
+expectedStdError <- NaN
-verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
+verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
     "perfect fit")
 
 # Longly
@@ -125,8 +133,9 @@ expectedResiduals <- c( 267.340029759711,-94.0139423988359,46.28716775752924,
 -13.18035686637081,14.30477260005235,455.394094551857,-17.26892711483297,
 -39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727,
 -206.7578251937366)
+expectedStdError <- 304.8540735619638
 
-verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
+verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
     "Longly")
 
 # Swiss Fertility (R dataset named "swiss")
@@ -215,7 +224,9 @@ expectedResiduals <- c(7.1044267859730512,1.6580347433531366,
 15.0147574652763112,4.8625103516321015,-7.1597256413907706,
 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063)
 
+expectedStdError <- 7.73642194433223
+
-verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors,
+verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError,
     "Swiss Fertility")
 
GLSMultipleLinearRegressionTest.java

@@ -18,6 +18,8 @@ package org.apache.commons.math.stat.regression;
 
 import org.junit.Before;
 import org.junit.Test;
+import org.apache.commons.math.TestUtils;
+import org.apache.commons.math.stat.StatUtils;
 
 public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
 
@@ -121,4 +123,16 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
         return y.length;
     }
 
+    /**
+     * test calculateYVariance
+     */
+    @Test
+    public void testYVariance() {
+
+        // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
+
+        GLSMultipleLinearRegression model = new GLSMultipleLinearRegression();
+        model.newSampleData(y, x, omega);
+        TestUtils.assertEquals(model.calculateYVariance(), 3.5, 0);
+    }
 }
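Where the expected value 3.5 in testYVariance comes from (simple arithmetic on the y vector named in the test comment, added for readability): the mean of {11, 12, 13, 14, 15, 16} is 13.5, so

    \operatorname{Var}(y) = \frac{2.5^2 + 1.5^2 + 0.5^2 + 0.5^2 + 1.5^2 + 2.5^2}{6 - 1} = \frac{17.5}{5} = 3.5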
OLSMultipleLinearRegressionTest.java

@@ -24,6 +24,7 @@ import org.apache.commons.math.linear.MatrixUtils;
 import org.apache.commons.math.linear.MatrixVisitorException;
 import org.apache.commons.math.linear.RealMatrix;
 import org.apache.commons.math.linear.Array2DRowRealMatrix;
+import org.apache.commons.math.stat.StatUtils;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -82,7 +83,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
     }
 
     @Test
-    public void testPerfectFit() {
+    public void testPerfectFit() throws Exception {
        double[] betaHat = regression.estimateRegressionParameters();
        TestUtils.assertEquals(betaHat,
                new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
@@ -108,6 +109,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
        assertEquals(0.0,
                     errors.subtract(referenceVariance).getNorm(),
                     5.0e-16 * referenceVariance.getNorm());
+
     }
 
 
@@ -122,7 +124,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
      * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
      */
     @Test
-    public void testLongly() {
+    public void testLongly() throws Exception {
        // Y values are first, then independent vars
        // Each row is one observation
        double[] design = new double[] {
@@ -180,6 +182,11 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
                0.214274163161675,
                0.226073200069370,
                455.478499142212}, errors, 1E-6);
+
+        // Check regression standard error against R
+        assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10);
+
+        checkVarianceConsistency(model);
     }
 
     /**
@@ -187,7 +194,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
      * Data Source: R datasets package
      */
     @Test
-    public void testSwissFertility() {
+    public void testSwissFertility() throws Exception {
         double[] design = new double[] {
             80.2,17.0,15,12,9.96,
             83.1,45.1,6,9,84.84,
@@ -283,6 +290,11 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
                0.27410957467466,
                0.19454551679325,
                0.03726654773803}, errors, 1E-10);
+
+        // Check regression standard error against R
+        assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12);
+
+        checkVarianceConsistency(model);
     }
 
     /**
@@ -354,4 +366,34 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
        double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
        TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
     }
+
+    /**
+     * test calculateYVariance
+     */
+    @Test
+    public void testYVariance() {
+
+        // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
+
+        OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
+        model.newSampleData(y, x);
+        TestUtils.assertEquals(model.calculateYVariance(), 3.5, 0);
+    }
+
+    /**
+     * Verifies that calculateYVariance and calculateResidualVariance return consistent
+     * values with direct variance computation from Y, residuals, respectively.
+     */
+    protected void checkVarianceConsistency(OLSMultipleLinearRegression model) throws Exception {
+        // Check Y variance consistency
+        TestUtils.assertEquals(StatUtils.variance(model.Y.getData()), model.calculateYVariance(), 0);
+
+        // Check residual variance consistency
+        double[] residuals = model.calculateResiduals().getData();
+        RealMatrix X = model.X;
+        TestUtils.assertEquals(
+                StatUtils.variance(model.calculateResiduals().getData()) * (residuals.length - 1),
+                model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20);
+
+    }
 }
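A note on why the second assertion in checkVarianceConsistency holds (reasoning added for readability, not text from the patch): StatUtils.variance(residuals) * (n - 1) is the centered sum of squares of the residuals, while calculateErrorVariance() * (n - k) equals the raw sum of squares u'u. The two agree whenever the residuals have zero mean, e.g. when the fitted model contains an intercept term:

    \operatorname{Var}(u)\,(n - 1) = \sum_{i=1}^{n} (u_i - \bar{u})^2 \;\overset{\bar{u}=0}{=}\; u^{\top} u = \hat{\sigma}_u^2\,(n - k)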