(MATH-649) SimpleRegression needs the ability to suppress the intercept
This commit pushes changes to allow the estimation of the a regression in which the intercept is constrained to be zero. I am also pushing two unit tests. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1167451 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
8929c05521
commit
88ced96cd0
|
@ -84,13 +84,23 @@ public class SimpleRegression implements Serializable {
|
|||
/** mean of accumulated y values, used in updating formulas */
|
||||
private double ybar = 0;
|
||||
|
||||
/** include an intercept or not */
|
||||
private final boolean hasIntercept;
|
||||
// ---------------------Public methods--------------------------------------
|
||||
|
||||
/**
|
||||
* Create an empty SimpleRegression instance
|
||||
*/
|
||||
public SimpleRegression() {
|
||||
this(true);
|
||||
}
|
||||
/**
|
||||
* Secondary constructor which allows the user the ability to include/exclude const
|
||||
* @param includeIntercept boolean flag, true includes an intercept
|
||||
*/
|
||||
public SimpleRegression(boolean includeIntercept){
|
||||
super();
|
||||
hasIntercept = includeIntercept;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -106,22 +116,32 @@ public class SimpleRegression implements Serializable {
|
|||
* @param x independent variable value
|
||||
* @param y dependent variable value
|
||||
*/
|
||||
public void addData(double x, double y) {
|
||||
public void addData(final double x, final double y){
|
||||
if (n == 0) {
|
||||
xbar = x;
|
||||
ybar = y;
|
||||
} else {
|
||||
double dx = x - xbar;
|
||||
double dy = y - ybar;
|
||||
sumXX += dx * dx * (double) n / (n + 1d);
|
||||
sumYY += dy * dy * (double) n / (n + 1d);
|
||||
sumXY += dx * dy * (double) n / (n + 1d);
|
||||
xbar += dx / (n + 1.0);
|
||||
ybar += dy / (n + 1.0);
|
||||
if( hasIntercept ){
|
||||
final double fact1 = 1.0 + (double) n;
|
||||
final double fact2 = ((double) n) / (1.0 + (double) n);
|
||||
final double dx = x - xbar;
|
||||
final double dy = y - ybar;
|
||||
sumXX += dx * dx * fact2;
|
||||
sumYY += dy * dy * fact2;
|
||||
sumXY += dx * dy * fact2;
|
||||
xbar += dx / fact1;
|
||||
ybar += dy / fact1;
|
||||
}
|
||||
}
|
||||
if( !hasIntercept ){
|
||||
sumXX += x * x ;
|
||||
sumYY += y * y ;
|
||||
sumXY += x * y ;
|
||||
}
|
||||
sumX += x;
|
||||
sumY += y;
|
||||
n++;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
@ -140,17 +160,29 @@ public class SimpleRegression implements Serializable {
|
|||
*/
|
||||
public void removeData(double x, double y) {
|
||||
if (n > 0) {
|
||||
double dx = x - xbar;
|
||||
double dy = y - ybar;
|
||||
sumXX -= dx * dx * (double) n / (n - 1d);
|
||||
sumYY -= dy * dy * (double) n / (n - 1d);
|
||||
sumXY -= dx * dy * (double) n / (n - 1d);
|
||||
xbar -= dx / (n - 1.0);
|
||||
ybar -= dy / (n - 1.0);
|
||||
sumX -= x;
|
||||
sumY -= y;
|
||||
n--;
|
||||
if (hasIntercept) {
|
||||
final double fact1 = (double) n - 1.0;
|
||||
final double fact2 = ((double) n) / ((double) n - 1.0);
|
||||
final double dx = x - xbar;
|
||||
final double dy = y - ybar;
|
||||
sumXX -= dx * dx * fact2;
|
||||
sumYY -= dy * dy * fact2;
|
||||
sumXY -= dx * dy * fact2;
|
||||
xbar -= dx / fact1;
|
||||
ybar -= dy / fact1;
|
||||
} else {
|
||||
final double fact1 = (double) n - 1.0;
|
||||
sumXX -= x * x;
|
||||
sumYY -= y * y;
|
||||
sumXY -= x * y;
|
||||
xbar -= x / fact1;
|
||||
ybar -= y / fact1;
|
||||
}
|
||||
sumX -= x;
|
||||
sumY -= y;
|
||||
n--;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -235,7 +267,10 @@ public class SimpleRegression implements Serializable {
|
|||
*/
|
||||
public double predict(double x) {
|
||||
double b1 = getSlope();
|
||||
return getIntercept(b1) + b1 * x;
|
||||
if (hasIntercept) {
|
||||
return getIntercept(b1) + b1 * x;
|
||||
}
|
||||
return b1 * x;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -255,7 +290,16 @@ public class SimpleRegression implements Serializable {
|
|||
* @return the intercept of the regression line
|
||||
*/
|
||||
public double getIntercept() {
|
||||
return getIntercept(getSlope());
|
||||
return hasIntercept ? getIntercept(getSlope()) : 0.0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if a constant has been included false otherwise.
|
||||
*
|
||||
* @return true if constant exists, false otherwise
|
||||
*/
|
||||
public boolean hasIntercept(){
|
||||
return hasIntercept;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -391,7 +435,7 @@ public class SimpleRegression implements Serializable {
|
|||
if (n < 3) {
|
||||
return Double.NaN;
|
||||
}
|
||||
return getSumSquaredErrors() / (n - 2);
|
||||
return hasIntercept ? (getSumSquaredErrors() / (n - 2)) : (getSumSquaredErrors() / (n - 1));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -443,11 +487,15 @@ public class SimpleRegression implements Serializable {
|
|||
* <p>
|
||||
* If there are fewer that <strong>three</strong> observations in the
|
||||
* model, or if there is no variation in x, this returns
|
||||
* <code>Double.NaN</code>.</p>
|
||||
* <code>Double.NaN</code>.</p> Additionally, a <code>Double.NaN</code> is
|
||||
* returned when the intercept is constrained to be zero
|
||||
*
|
||||
* @return standard error associated with intercept estimate
|
||||
*/
|
||||
public double getInterceptStdErr() {
|
||||
if( !hasIntercept ){
|
||||
return Double.NaN;
|
||||
}
|
||||
return FastMath.sqrt(
|
||||
getMeanSquareError() * ((1d / (double) n) + (xbar * xbar) / sumXX));
|
||||
}
|
||||
|
@ -572,8 +620,11 @@ public class SimpleRegression implements Serializable {
|
|||
* @param slope current slope
|
||||
* @return the intercept of the regression line
|
||||
*/
|
||||
private double getIntercept(double slope) {
|
||||
private double getIntercept(double slope){
|
||||
if( hasIntercept){
|
||||
return (sumY - slope * sumX) / n;
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -80,6 +80,76 @@ public final class SimpleRegressionTest {
|
|||
{5, -1 }, {6, 12 }
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
* Data from NIST NOINT1
|
||||
*/
|
||||
private double[][] noint1 = {
|
||||
{130.0,60.0},
|
||||
{131.0,61.0},
|
||||
{132.0,62.0},
|
||||
{133.0,63.0},
|
||||
{134.0,64.0},
|
||||
{135.0,65.0},
|
||||
{136.0,66.0},
|
||||
{137.0,67.0},
|
||||
{138.0,68.0},
|
||||
{139.0,69.0},
|
||||
{140.0,70.0}
|
||||
};
|
||||
|
||||
/*
|
||||
* Data from NIST NOINT2
|
||||
*
|
||||
*/
|
||||
private double[][] noint2 = {
|
||||
{3.0,4},
|
||||
{4,5},
|
||||
{4,6}
|
||||
};
|
||||
|
||||
@Test
|
||||
public void testNoInterceot_noint2(){
|
||||
SimpleRegression regression = new SimpleRegression(false);
|
||||
regression.addData(noint2[0][1], noint2[0][0]);
|
||||
regression.addData(noint2[1][1], noint2[1][0]);
|
||||
regression.addData(noint2[2][1], noint2[2][0]);
|
||||
Assert.assertEquals("slope", 0.727272727272727,
|
||||
regression.getSlope(), 10E-12);
|
||||
Assert.assertEquals("slope std err", 0.420827318078432E-01,
|
||||
regression.getSlopeStdErr(),10E-12);
|
||||
Assert.assertEquals("number of observations", 3, regression.getN());
|
||||
Assert.assertEquals("r-square", 0.993348115299335,
|
||||
regression.getRSquare(), 10E-12);
|
||||
Assert.assertEquals("SSR", 40.7272727272727,
|
||||
regression.getRegressionSumSquares(), 10E-9);
|
||||
Assert.assertEquals("MSE", 0.136363636363636,
|
||||
regression.getMeanSquareError(), 10E-10);
|
||||
Assert.assertEquals("SSE", 0.272727272727273,
|
||||
regression.getSumSquaredErrors(),10E-9);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoIntercept_noint1(){
|
||||
SimpleRegression regression = new SimpleRegression(false);
|
||||
for (int i = 0; i < noint1.length; i++) {
|
||||
regression.addData(noint1[i][1], noint1[i][0]);
|
||||
}
|
||||
Assert.assertEquals("slope", 2.07438016528926, regression.getSlope(), 10E-12);
|
||||
Assert.assertEquals("slope std err", 0.165289256198347E-01,
|
||||
regression.getSlopeStdErr(),10E-12);
|
||||
Assert.assertEquals("number of observations", 11, regression.getN());
|
||||
Assert.assertEquals("r-square", 0.999365492298663,
|
||||
regression.getRSquare(), 10E-12);
|
||||
Assert.assertEquals("SSR", 200457.727272727,
|
||||
regression.getRegressionSumSquares(), 10E-9);
|
||||
Assert.assertEquals("MSE", 12.7272727272727,
|
||||
regression.getMeanSquareError(), 10E-10);
|
||||
Assert.assertEquals("SSE", 127.272727272727,
|
||||
regression.getSumSquaredErrors(),10E-9);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNorris() {
|
||||
SimpleRegression regression = new SimpleRegression();
|
||||
|
|
Loading…
Reference in New Issue