(MATH-649) SimpleRegression needs the ability to suppress the intercept

This commit pushes changes to allow the estimation of the a regression in which the intercept is constrained to be zero. I am also pushing two unit tests. 

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1167451 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Greg Sterijevski 2011-09-10 04:18:31 +00:00
parent 8929c05521
commit 88ced96cd0
2 changed files with 144 additions and 23 deletions

View File

@ -84,13 +84,23 @@ public class SimpleRegression implements Serializable {
/** mean of accumulated y values, used in updating formulas */
private double ybar = 0;
/** include an intercept or not */
private final boolean hasIntercept;
// ---------------------Public methods--------------------------------------
/**
* Create an empty SimpleRegression instance
*/
public SimpleRegression() {
this(true);
}
/**
* Secondary constructor which allows the user the ability to include/exclude const
* @param includeIntercept boolean flag, true includes an intercept
*/
public SimpleRegression(boolean includeIntercept){
super();
hasIntercept = includeIntercept;
}
/**
@ -106,22 +116,32 @@ public class SimpleRegression implements Serializable {
* @param x independent variable value
* @param y dependent variable value
*/
public void addData(double x, double y) {
public void addData(final double x, final double y){
if (n == 0) {
xbar = x;
ybar = y;
} else {
double dx = x - xbar;
double dy = y - ybar;
sumXX += dx * dx * (double) n / (n + 1d);
sumYY += dy * dy * (double) n / (n + 1d);
sumXY += dx * dy * (double) n / (n + 1d);
xbar += dx / (n + 1.0);
ybar += dy / (n + 1.0);
if( hasIntercept ){
final double fact1 = 1.0 + (double) n;
final double fact2 = ((double) n) / (1.0 + (double) n);
final double dx = x - xbar;
final double dy = y - ybar;
sumXX += dx * dx * fact2;
sumYY += dy * dy * fact2;
sumXY += dx * dy * fact2;
xbar += dx / fact1;
ybar += dy / fact1;
}
}
if( !hasIntercept ){
sumXX += x * x ;
sumYY += y * y ;
sumXY += x * y ;
}
sumX += x;
sumY += y;
n++;
return;
}
@ -140,17 +160,29 @@ public class SimpleRegression implements Serializable {
*/
public void removeData(double x, double y) {
if (n > 0) {
double dx = x - xbar;
double dy = y - ybar;
sumXX -= dx * dx * (double) n / (n - 1d);
sumYY -= dy * dy * (double) n / (n - 1d);
sumXY -= dx * dy * (double) n / (n - 1d);
xbar -= dx / (n - 1.0);
ybar -= dy / (n - 1.0);
if (hasIntercept) {
final double fact1 = (double) n - 1.0;
final double fact2 = ((double) n) / ((double) n - 1.0);
final double dx = x - xbar;
final double dy = y - ybar;
sumXX -= dx * dx * fact2;
sumYY -= dy * dy * fact2;
sumXY -= dx * dy * fact2;
xbar -= dx / fact1;
ybar -= dy / fact1;
} else {
final double fact1 = (double) n - 1.0;
sumXX -= x * x;
sumYY -= y * y;
sumXY -= x * y;
xbar -= x / fact1;
ybar -= y / fact1;
}
sumX -= x;
sumY -= y;
n--;
}
return;
}
/**
@ -235,8 +267,11 @@ public class SimpleRegression implements Serializable {
*/
public double predict(double x) {
double b1 = getSlope();
if (hasIntercept) {
return getIntercept(b1) + b1 * x;
}
return b1 * x;
}
/**
* Returns the intercept of the estimated regression line.
@ -255,7 +290,16 @@ public class SimpleRegression implements Serializable {
* @return the intercept of the regression line
*/
public double getIntercept() {
return getIntercept(getSlope());
return hasIntercept ? getIntercept(getSlope()) : 0.0;
}
/**
* Returns true if a constant has been included false otherwise.
*
* @return true if constant exists, false otherwise
*/
public boolean hasIntercept(){
return hasIntercept;
}
/**
@ -391,7 +435,7 @@ public class SimpleRegression implements Serializable {
if (n < 3) {
return Double.NaN;
}
return getSumSquaredErrors() / (n - 2);
return hasIntercept ? (getSumSquaredErrors() / (n - 2)) : (getSumSquaredErrors() / (n - 1));
}
/**
@ -443,11 +487,15 @@ public class SimpleRegression implements Serializable {
* <p>
* If there are fewer that <strong>three</strong> observations in the
* model, or if there is no variation in x, this returns
* <code>Double.NaN</code>.</p>
* <code>Double.NaN</code>.</p> Additionally, a <code>Double.NaN</code> is
* returned when the intercept is constrained to be zero
*
* @return standard error associated with intercept estimate
*/
public double getInterceptStdErr() {
if( !hasIntercept ){
return Double.NaN;
}
return FastMath.sqrt(
getMeanSquareError() * ((1d / (double) n) + (xbar * xbar) / sumXX));
}
@ -573,8 +621,11 @@ public class SimpleRegression implements Serializable {
* @return the intercept of the regression line
*/
private double getIntercept(double slope){
if( hasIntercept){
return (sumY - slope * sumX) / n;
}
return 0.0;
}
/**
* Computes SSR from b1.

View File

@ -80,6 +80,76 @@ public final class SimpleRegressionTest {
{5, -1 }, {6, 12 }
};
/*
* Data from NIST NOINT1
*/
private double[][] noint1 = {
{130.0,60.0},
{131.0,61.0},
{132.0,62.0},
{133.0,63.0},
{134.0,64.0},
{135.0,65.0},
{136.0,66.0},
{137.0,67.0},
{138.0,68.0},
{139.0,69.0},
{140.0,70.0}
};
/*
* Data from NIST NOINT2
*
*/
private double[][] noint2 = {
{3.0,4},
{4,5},
{4,6}
};
@Test
public void testNoInterceot_noint2(){
SimpleRegression regression = new SimpleRegression(false);
regression.addData(noint2[0][1], noint2[0][0]);
regression.addData(noint2[1][1], noint2[1][0]);
regression.addData(noint2[2][1], noint2[2][0]);
Assert.assertEquals("slope", 0.727272727272727,
regression.getSlope(), 10E-12);
Assert.assertEquals("slope std err", 0.420827318078432E-01,
regression.getSlopeStdErr(),10E-12);
Assert.assertEquals("number of observations", 3, regression.getN());
Assert.assertEquals("r-square", 0.993348115299335,
regression.getRSquare(), 10E-12);
Assert.assertEquals("SSR", 40.7272727272727,
regression.getRegressionSumSquares(), 10E-9);
Assert.assertEquals("MSE", 0.136363636363636,
regression.getMeanSquareError(), 10E-10);
Assert.assertEquals("SSE", 0.272727272727273,
regression.getSumSquaredErrors(),10E-9);
}
@Test
public void testNoIntercept_noint1(){
SimpleRegression regression = new SimpleRegression(false);
for (int i = 0; i < noint1.length; i++) {
regression.addData(noint1[i][1], noint1[i][0]);
}
Assert.assertEquals("slope", 2.07438016528926, regression.getSlope(), 10E-12);
Assert.assertEquals("slope std err", 0.165289256198347E-01,
regression.getSlopeStdErr(),10E-12);
Assert.assertEquals("number of observations", 11, regression.getN());
Assert.assertEquals("r-square", 0.999365492298663,
regression.getRSquare(), 10E-12);
Assert.assertEquals("SSR", 200457.727272727,
regression.getRegressionSumSquares(), 10E-9);
Assert.assertEquals("MSE", 12.7272727272727,
regression.getMeanSquareError(), 10E-10);
Assert.assertEquals("SSE", 127.272727272727,
regression.getSumSquaredErrors(),10E-9);
}
@Test
public void testNorris() {
SimpleRegression regression = new SimpleRegression();