diff --git a/src/java/org/apache/commons/math/stat/BivariateRegression.java b/src/java/org/apache/commons/math/stat/BivariateRegression.java index 619701800..865cff5a4 100644 --- a/src/java/org/apache/commons/math/stat/BivariateRegression.java +++ b/src/java/org/apache/commons/math/stat/BivariateRegression.java @@ -85,21 +85,21 @@ import org.apache.commons.math.stat.distribution.TDistribution; * * * @author Phil Steitz - * @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $ + * @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $ */ public class BivariateRegression { /** sum of x values */ private double sumX = 0d; - /** sum of squared x values */ - private double sumSqX = 0d; + /** total variation in x (sum of squared deviations from xbar) */ + private double sumXX = 0d; /** sum of y values */ private double sumY = 0d; - /** sum of squared y values */ - private double sumSqY = 0d; + /** total variation in y (sum of squared deviations from ybar) */ + private double sumYY = 0d; /** sum of products */ private double sumXY = 0d; @@ -107,20 +107,41 @@ public class BivariateRegression { /** number of observations */ private long n = 0; + /** mean of accumulated x values, used in updating formulas */ + private double xbar = 0; + + /** mean of accumulated y values, used in updating formulas */ + private double ybar = 0; + + // ---------------------Public methods-------------------------------------- /** - * Adds the observation (x,y) to the regression data set + * Adds the observation (x,y) to the regression data set. + *

+ * Uses updating formulas for means and sums of squares defined in + * "Algorithms for Computing the Sample Variance: Analysis and + * Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J. + * 1983, American Statistician, vol. 37, pp. 242-247, referenced in + * Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985 + * * * @param x independent variable value * @param y dependent variable value */ public void addData(double x, double y) { + if (n == 0) { + xbar = x; + ybar = y; + } else { + sumXX += ((double) n / (double) (n + 1)) * (x - xbar) * (x - xbar); + sumYY += ((double) n / (double) (n + 1)) * (y - ybar) * (y - ybar); + sumXY += ((double) n / (double) (n + 1)) * (x - xbar) * (y - ybar); + xbar += (1d / (double) (n + 1)) * (x - xbar); + ybar += (1d / (double) (n + 1)) * (y - ybar); + } sumX += x; - sumSqX += x * x; sumY += y; - sumSqY += y * y; - sumXY += x * y; n++; } @@ -148,9 +169,9 @@ public class BivariateRegression { */ public void clear() { sumX = 0d; - sumSqX = 0d; + sumXX = 0d; sumY = 0d; - sumSqY = 0d; + sumYY = 0d; sumXY = 0d; n = 0; } @@ -215,7 +236,7 @@ public class BivariateRegression { * Preconditions:

* @@ -225,12 +246,10 @@ public class BivariateRegression { if (n < 2) { return Double.NaN; //not enough data } - double dn = (double) n; - double denom = sumSqX - (sumX * sumX / dn); - if (Math.abs(denom) < 10 * Double.MIN_VALUE) { + if (Math.abs(sumXX) < 10 * Double.MIN_VALUE) { return Double.NaN; //not enough variation in x } - return (sumXY - (sumX * sumY / dn)) / denom; + return sumXY / sumXX; } /** @@ -265,7 +284,7 @@ public class BivariateRegression { if (n < 2) { return Double.NaN; } - return sumSqY - sumY * sumY / (double) n; + return sumYY; } /** @@ -282,11 +301,10 @@ public class BivariateRegression { * returned. * * - * @return sum of squared deviations of y values + * @return sum of squared deviations of predicted y values */ public double getRegressionSumSquares() { - double b1 = getSlope(); - return b1 * (sumXY - sumX * sumY / (double) n); + return getRegressionSumSquares(getSlope()); } /** @@ -303,8 +321,7 @@ public class BivariateRegression { if (n < 3) { return Double.NaN; } - double sse = getSumSquaredErrors(); - return sse / (double) (n - 2); + return getSumSquaredErrors() / (double) (n - 2); } /** @@ -361,8 +378,8 @@ public class BivariateRegression { * @return standard error associated with intercept estimate */ public double getInterceptStdErr() { - double ssx = getSumSquaresX(); - return Math.sqrt(getMeanSquareError() * sumSqX / (((double) n) * ssx)); + return Math.sqrt(getMeanSquareError() * ((1d / (double) n) + + (xbar * xbar) / sumXX)); } /** @@ -376,8 +393,7 @@ public class BivariateRegression { * @return standard error associated with slope estimate */ public double getSlopeStdErr() { - double ssx = getSumSquaresX(); - return Math.sqrt(getMeanSquareError() / ssx); + return Math.sqrt(getMeanSquareError() / sumXX); } /** @@ -492,24 +508,9 @@ public class BivariateRegression { * @return sum of squared errors associated with the regression model */ private double getSumSquaredErrors(double b1) { - double b0 = getIntercept(b1); - return sumSqY - b0 * sumY - b1 * sumXY; + return sumYY - sumXY * sumXY / sumXX; } - /** - * Returns the sum of squared deviations of the x values about their mean. - *

- * If n < 2, this returns NaN. - * - * @return sum of squared deviations of x values - */ - private double getSumSquaresX() { - if (n < 2) { - return Double.NaN; - } - return sumSqX - sumX * sumX / (double) n; - } - /** * Computes r-square from the slope. *

@@ -523,6 +524,16 @@ public class BivariateRegression { return (ssto - getSumSquaredErrors(b1)) / ssto; } + /** + * Computes SSR from b1. + * + * @param slope regression slope estimate + * @return sum of squared deviations of predicted y values + */ + private double getRegressionSumSquares(double slope) { + return slope * slope * sumXX; + } + /** * Uses distribution framework to get a t distribution instance * with df = n - 2 diff --git a/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java b/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java index b31bda655..5e4481146 100644 --- a/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java +++ b/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java @@ -60,7 +60,7 @@ import junit.framework.TestSuite; * Test cases for the TestStatistic class. * * @author Phil Steitz - * @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $ + * @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $ */ public final class BivariateRegressionTest extends TestCase { @@ -130,15 +130,15 @@ public final class BivariateRegressionTest extends TestCase { assertEquals("r-square",0.999993745883712, regression.getRSquare(),10E-12); assertEquals("SSR",4255954.13232369, - regression.getRegressionSumSquares(),10E-8); + regression.getRegressionSumSquares(),10E-9); assertEquals("MSE",0.782864662630069, - regression.getMeanSquareError(),10E-8); + regression.getMeanSquareError(),10E-10); assertEquals("SSE",26.6173985294224, - regression.getSumSquaredErrors(),10E-8); + regression.getSumSquaredErrors(),10E-9); assertEquals("predict(0)",-0.262323073774029, regression.predict(0),10E-12); assertEquals("predict(1)",1.00211681802045-0.262323073774029, - regression.predict(1),10E-11); + regression.predict(1),10E-12); } public void testCorr() {