diff --git a/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java b/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java index 08ba5dad6..ccb972d94 100644 --- a/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java +++ b/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java @@ -84,13 +84,23 @@ public class SimpleRegression implements Serializable { /** mean of accumulated y values, used in updating formulas */ private double ybar = 0; + /** include an intercept or not */ + private final boolean hasIntercept; // ---------------------Public methods-------------------------------------- /** * Create an empty SimpleRegression instance */ public SimpleRegression() { + this(true); + } + /** + * Secondary constructor which allows the user the ability to include/exclude const + * @param includeIntercept boolean flag, true includes an intercept + */ + public SimpleRegression(boolean includeIntercept){ super(); + hasIntercept = includeIntercept; } /** @@ -106,22 +116,32 @@ public class SimpleRegression implements Serializable { * @param x independent variable value * @param y dependent variable value */ - public void addData(double x, double y) { + public void addData(final double x, final double y){ if (n == 0) { xbar = x; ybar = y; } else { - double dx = x - xbar; - double dy = y - ybar; - sumXX += dx * dx * (double) n / (n + 1d); - sumYY += dy * dy * (double) n / (n + 1d); - sumXY += dx * dy * (double) n / (n + 1d); - xbar += dx / (n + 1.0); - ybar += dy / (n + 1.0); + if( hasIntercept ){ + final double fact1 = 1.0 + (double) n; + final double fact2 = ((double) n) / (1.0 + (double) n); + final double dx = x - xbar; + final double dy = y - ybar; + sumXX += dx * dx * fact2; + sumYY += dy * dy * fact2; + sumXY += dx * dy * fact2; + xbar += dx / fact1; + ybar += dy / fact1; + } + } + if( !hasIntercept ){ + sumXX += x * x ; + sumYY += y * y ; + sumXY += x * y ; } sumX += x; sumY += y; n++; + return; } @@ -140,17 +160,29 @@ public class SimpleRegression implements Serializable { */ public void removeData(double x, double y) { if (n > 0) { - double dx = x - xbar; - double dy = y - ybar; - sumXX -= dx * dx * (double) n / (n - 1d); - sumYY -= dy * dy * (double) n / (n - 1d); - sumXY -= dx * dy * (double) n / (n - 1d); - xbar -= dx / (n - 1.0); - ybar -= dy / (n - 1.0); - sumX -= x; - sumY -= y; - n--; + if (hasIntercept) { + final double fact1 = (double) n - 1.0; + final double fact2 = ((double) n) / ((double) n - 1.0); + final double dx = x - xbar; + final double dy = y - ybar; + sumXX -= dx * dx * fact2; + sumYY -= dy * dy * fact2; + sumXY -= dx * dy * fact2; + xbar -= dx / fact1; + ybar -= dy / fact1; + } else { + final double fact1 = (double) n - 1.0; + sumXX -= x * x; + sumYY -= y * y; + sumXY -= x * y; + xbar -= x / fact1; + ybar -= y / fact1; + } + sumX -= x; + sumY -= y; + n--; } + return; } /** @@ -235,7 +267,10 @@ public class SimpleRegression implements Serializable { */ public double predict(double x) { double b1 = getSlope(); - return getIntercept(b1) + b1 * x; + if (hasIntercept) { + return getIntercept(b1) + b1 * x; + } + return b1 * x; } /** @@ -255,7 +290,16 @@ public class SimpleRegression implements Serializable { * @return the intercept of the regression line */ public double getIntercept() { - return getIntercept(getSlope()); + return hasIntercept ? getIntercept(getSlope()) : 0.0; + } + + /** + * Returns true if a constant has been included false otherwise. + * + * @return true if constant exists, false otherwise + */ + public boolean hasIntercept(){ + return hasIntercept; } /** @@ -391,7 +435,7 @@ public class SimpleRegression implements Serializable { if (n < 3) { return Double.NaN; } - return getSumSquaredErrors() / (n - 2); + return hasIntercept ? (getSumSquaredErrors() / (n - 2)) : (getSumSquaredErrors() / (n - 1)); } /** @@ -443,11 +487,15 @@ public class SimpleRegression implements Serializable { *

* If there are fewer that three observations in the * model, or if there is no variation in x, this returns - * Double.NaN.

+ * Double.NaN.

Additionally, a Double.NaN is + * returned when the intercept is constrained to be zero * * @return standard error associated with intercept estimate */ public double getInterceptStdErr() { + if( !hasIntercept ){ + return Double.NaN; + } return FastMath.sqrt( getMeanSquareError() * ((1d / (double) n) + (xbar * xbar) / sumXX)); } @@ -572,8 +620,11 @@ public class SimpleRegression implements Serializable { * @param slope current slope * @return the intercept of the regression line */ - private double getIntercept(double slope) { + private double getIntercept(double slope){ + if( hasIntercept){ return (sumY - slope * sumX) / n; + } + return 0.0; } /** diff --git a/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java b/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java index bccb2bc61..2789f4256 100644 --- a/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java +++ b/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java @@ -80,6 +80,76 @@ public final class SimpleRegressionTest { {5, -1 }, {6, 12 } }; + + /* + * Data from NIST NOINT1 + */ + private double[][] noint1 = { + {130.0,60.0}, + {131.0,61.0}, + {132.0,62.0}, + {133.0,63.0}, + {134.0,64.0}, + {135.0,65.0}, + {136.0,66.0}, + {137.0,67.0}, + {138.0,68.0}, + {139.0,69.0}, + {140.0,70.0} + }; + + /* + * Data from NIST NOINT2 + * + */ + private double[][] noint2 = { + {3.0,4}, + {4,5}, + {4,6} + }; + + @Test + public void testNoInterceot_noint2(){ + SimpleRegression regression = new SimpleRegression(false); + regression.addData(noint2[0][1], noint2[0][0]); + regression.addData(noint2[1][1], noint2[1][0]); + regression.addData(noint2[2][1], noint2[2][0]); + Assert.assertEquals("slope", 0.727272727272727, + regression.getSlope(), 10E-12); + Assert.assertEquals("slope std err", 0.420827318078432E-01, + regression.getSlopeStdErr(),10E-12); + Assert.assertEquals("number of observations", 3, regression.getN()); + Assert.assertEquals("r-square", 0.993348115299335, + regression.getRSquare(), 10E-12); + Assert.assertEquals("SSR", 40.7272727272727, + regression.getRegressionSumSquares(), 10E-9); + Assert.assertEquals("MSE", 0.136363636363636, + regression.getMeanSquareError(), 10E-10); + Assert.assertEquals("SSE", 0.272727272727273, + regression.getSumSquaredErrors(),10E-9); + } + + @Test + public void testNoIntercept_noint1(){ + SimpleRegression regression = new SimpleRegression(false); + for (int i = 0; i < noint1.length; i++) { + regression.addData(noint1[i][1], noint1[i][0]); + } + Assert.assertEquals("slope", 2.07438016528926, regression.getSlope(), 10E-12); + Assert.assertEquals("slope std err", 0.165289256198347E-01, + regression.getSlopeStdErr(),10E-12); + Assert.assertEquals("number of observations", 11, regression.getN()); + Assert.assertEquals("r-square", 0.999365492298663, + regression.getRSquare(), 10E-12); + Assert.assertEquals("SSR", 200457.727272727, + regression.getRegressionSumSquares(), 10E-9); + Assert.assertEquals("MSE", 12.7272727272727, + regression.getMeanSquareError(), 10E-10); + Assert.assertEquals("SSE", 127.272727272727, + regression.getSumSquaredErrors(),10E-9); + + } + @Test public void testNorris() { SimpleRegression regression = new SimpleRegression();