diff --git a/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java b/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java index 08ba5dad6..ccb972d94 100644 --- a/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java +++ b/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java @@ -84,13 +84,23 @@ public class SimpleRegression implements Serializable { /** mean of accumulated y values, used in updating formulas */ private double ybar = 0; + /** include an intercept or not */ + private final boolean hasIntercept; // ---------------------Public methods-------------------------------------- /** * Create an empty SimpleRegression instance */ public SimpleRegression() { + this(true); + } + /** + * Secondary constructor which allows the user the ability to include/exclude const + * @param includeIntercept boolean flag, true includes an intercept + */ + public SimpleRegression(boolean includeIntercept){ super(); + hasIntercept = includeIntercept; } /** @@ -106,22 +116,32 @@ public class SimpleRegression implements Serializable { * @param x independent variable value * @param y dependent variable value */ - public void addData(double x, double y) { + public void addData(final double x, final double y){ if (n == 0) { xbar = x; ybar = y; } else { - double dx = x - xbar; - double dy = y - ybar; - sumXX += dx * dx * (double) n / (n + 1d); - sumYY += dy * dy * (double) n / (n + 1d); - sumXY += dx * dy * (double) n / (n + 1d); - xbar += dx / (n + 1.0); - ybar += dy / (n + 1.0); + if( hasIntercept ){ + final double fact1 = 1.0 + (double) n; + final double fact2 = ((double) n) / (1.0 + (double) n); + final double dx = x - xbar; + final double dy = y - ybar; + sumXX += dx * dx * fact2; + sumYY += dy * dy * fact2; + sumXY += dx * dy * fact2; + xbar += dx / fact1; + ybar += dy / fact1; + } + } + if( !hasIntercept ){ + sumXX += x * x ; + sumYY += y * y ; + sumXY += x * y ; } sumX += x; sumY += y; n++; + return; } @@ -140,17 +160,29 @@ public class SimpleRegression implements Serializable { */ public void removeData(double x, double y) { if (n > 0) { - double dx = x - xbar; - double dy = y - ybar; - sumXX -= dx * dx * (double) n / (n - 1d); - sumYY -= dy * dy * (double) n / (n - 1d); - sumXY -= dx * dy * (double) n / (n - 1d); - xbar -= dx / (n - 1.0); - ybar -= dy / (n - 1.0); - sumX -= x; - sumY -= y; - n--; + if (hasIntercept) { + final double fact1 = (double) n - 1.0; + final double fact2 = ((double) n) / ((double) n - 1.0); + final double dx = x - xbar; + final double dy = y - ybar; + sumXX -= dx * dx * fact2; + sumYY -= dy * dy * fact2; + sumXY -= dx * dy * fact2; + xbar -= dx / fact1; + ybar -= dy / fact1; + } else { + final double fact1 = (double) n - 1.0; + sumXX -= x * x; + sumYY -= y * y; + sumXY -= x * y; + xbar -= x / fact1; + ybar -= y / fact1; + } + sumX -= x; + sumY -= y; + n--; } + return; } /** @@ -235,7 +267,10 @@ public class SimpleRegression implements Serializable { */ public double predict(double x) { double b1 = getSlope(); - return getIntercept(b1) + b1 * x; + if (hasIntercept) { + return getIntercept(b1) + b1 * x; + } + return b1 * x; } /** @@ -255,7 +290,16 @@ public class SimpleRegression implements Serializable { * @return the intercept of the regression line */ public double getIntercept() { - return getIntercept(getSlope()); + return hasIntercept ? getIntercept(getSlope()) : 0.0; + } + + /** + * Returns true if a constant has been included false otherwise. + * + * @return true if constant exists, false otherwise + */ + public boolean hasIntercept(){ + return hasIntercept; } /** @@ -391,7 +435,7 @@ public class SimpleRegression implements Serializable { if (n < 3) { return Double.NaN; } - return getSumSquaredErrors() / (n - 2); + return hasIntercept ? (getSumSquaredErrors() / (n - 2)) : (getSumSquaredErrors() / (n - 1)); } /** @@ -443,11 +487,15 @@ public class SimpleRegression implements Serializable { *
* If there are fewer that three observations in the
* model, or if there is no variation in x, this returns
- * Double.NaN
.
Double.NaN
. Additionally, a Double.NaN
is
+ * returned when the intercept is constrained to be zero
*
* @return standard error associated with intercept estimate
*/
public double getInterceptStdErr() {
+ if( !hasIntercept ){
+ return Double.NaN;
+ }
return FastMath.sqrt(
getMeanSquareError() * ((1d / (double) n) + (xbar * xbar) / sumXX));
}
@@ -572,8 +620,11 @@ public class SimpleRegression implements Serializable {
* @param slope current slope
* @return the intercept of the regression line
*/
- private double getIntercept(double slope) {
+ private double getIntercept(double slope){
+ if( hasIntercept){
return (sumY - slope * sumX) / n;
+ }
+ return 0.0;
}
/**
diff --git a/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java b/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
index bccb2bc61..2789f4256 100644
--- a/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
+++ b/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
@@ -80,6 +80,76 @@ public final class SimpleRegressionTest {
{5, -1 }, {6, 12 }
};
+
+ /*
+ * Data from NIST NOINT1
+ */
+ private double[][] noint1 = {
+ {130.0,60.0},
+ {131.0,61.0},
+ {132.0,62.0},
+ {133.0,63.0},
+ {134.0,64.0},
+ {135.0,65.0},
+ {136.0,66.0},
+ {137.0,67.0},
+ {138.0,68.0},
+ {139.0,69.0},
+ {140.0,70.0}
+ };
+
+ /*
+ * Data from NIST NOINT2
+ *
+ */
+ private double[][] noint2 = {
+ {3.0,4},
+ {4,5},
+ {4,6}
+ };
+
+ @Test
+ public void testNoInterceot_noint2(){
+ SimpleRegression regression = new SimpleRegression(false);
+ regression.addData(noint2[0][1], noint2[0][0]);
+ regression.addData(noint2[1][1], noint2[1][0]);
+ regression.addData(noint2[2][1], noint2[2][0]);
+ Assert.assertEquals("slope", 0.727272727272727,
+ regression.getSlope(), 10E-12);
+ Assert.assertEquals("slope std err", 0.420827318078432E-01,
+ regression.getSlopeStdErr(),10E-12);
+ Assert.assertEquals("number of observations", 3, regression.getN());
+ Assert.assertEquals("r-square", 0.993348115299335,
+ regression.getRSquare(), 10E-12);
+ Assert.assertEquals("SSR", 40.7272727272727,
+ regression.getRegressionSumSquares(), 10E-9);
+ Assert.assertEquals("MSE", 0.136363636363636,
+ regression.getMeanSquareError(), 10E-10);
+ Assert.assertEquals("SSE", 0.272727272727273,
+ regression.getSumSquaredErrors(),10E-9);
+ }
+
+ @Test
+ public void testNoIntercept_noint1(){
+ SimpleRegression regression = new SimpleRegression(false);
+ for (int i = 0; i < noint1.length; i++) {
+ regression.addData(noint1[i][1], noint1[i][0]);
+ }
+ Assert.assertEquals("slope", 2.07438016528926, regression.getSlope(), 10E-12);
+ Assert.assertEquals("slope std err", 0.165289256198347E-01,
+ regression.getSlopeStdErr(),10E-12);
+ Assert.assertEquals("number of observations", 11, regression.getN());
+ Assert.assertEquals("r-square", 0.999365492298663,
+ regression.getRSquare(), 10E-12);
+ Assert.assertEquals("SSR", 200457.727272727,
+ regression.getRegressionSumSquares(), 10E-9);
+ Assert.assertEquals("MSE", 12.7272727272727,
+ regression.getMeanSquareError(), 10E-10);
+ Assert.assertEquals("SSE", 127.272727272727,
+ regression.getSumSquaredErrors(),10E-9);
+
+ }
+
@Test
public void testNorris() {
SimpleRegression regression = new SimpleRegression();