diff --git a/src/java/org/apache/commons/math/stat/BivariateRegression.java b/src/java/org/apache/commons/math/stat/BivariateRegression.java index 619701800..865cff5a4 100644 --- a/src/java/org/apache/commons/math/stat/BivariateRegression.java +++ b/src/java/org/apache/commons/math/stat/BivariateRegression.java @@ -85,21 +85,21 @@ import org.apache.commons.math.stat.distribution.TDistribution; * * * @author Phil Steitz - * @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $ + * @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $ */ public class BivariateRegression { /** sum of x values */ private double sumX = 0d; - /** sum of squared x values */ - private double sumSqX = 0d; + /** total variation in x (sum of squared deviations from xbar) */ + private double sumXX = 0d; /** sum of y values */ private double sumY = 0d; - /** sum of squared y values */ - private double sumSqY = 0d; + /** total variation in y (sum of squared deviations from ybar) */ + private double sumYY = 0d; /** sum of products */ private double sumXY = 0d; @@ -107,20 +107,41 @@ public class BivariateRegression { /** number of observations */ private long n = 0; + /** mean of accumulated x values, used in updating formulas */ + private double xbar = 0; + + /** mean of accumulated y values, used in updating formulas */ + private double ybar = 0; + + // ---------------------Public methods-------------------------------------- /** - * Adds the observation (x,y) to the regression data set + * Adds the observation (x,y) to the regression data set. + *
+ * Uses updating formulas for means and sums of squares defined in + * "Algorithms for Computing the Sample Variance: Analysis and + * Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J. + * 1983, American Statistician, vol. 37, pp. 242-247, referenced in + * Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985 + * * * @param x independent variable value * @param y dependent variable value */ public void addData(double x, double y) { + if (n == 0) { + xbar = x; + ybar = y; + } else { + sumXX += ((double) n / (double) (n + 1)) * (x - xbar) * (x - xbar); + sumYY += ((double) n / (double) (n + 1)) * (y - ybar) * (y - ybar); + sumXY += ((double) n / (double) (n + 1)) * (x - xbar) * (y - ybar); + xbar += (1d / (double) (n + 1)) * (x - xbar); + ybar += (1d / (double) (n + 1)) * (y - ybar); + } sumX += x; - sumSqX += x * x; sumY += y; - sumSqY += y * y; - sumXY += x * y; n++; } @@ -148,9 +169,9 @@ public class BivariateRegression { */ public void clear() { sumX = 0d; - sumSqX = 0d; + sumXX = 0d; sumY = 0d; - sumSqY = 0d; + sumYY = 0d; sumXY = 0d; n = 0; } @@ -215,7 +236,7 @@ public class BivariateRegression { * Preconditions:
Double,NaN
is
+ * invoked before a model can be estimated, Double.NaN
is
* returned.
* - * If n < 2, this returns NaN. - * - * @return sum of squared deviations of x values - */ - private double getSumSquaresX() { - if (n < 2) { - return Double.NaN; - } - return sumSqX - sumX * sumX / (double) n; - } - /** * Computes r-square from the slope. *
@@ -523,6 +524,16 @@ public class BivariateRegression { return (ssto - getSumSquaredErrors(b1)) / ssto; } + /** + * Computes SSR from b1. + * + * @param slope regression slope estimate + * @return sum of squared deviations of predicted y values + */ + private double getRegressionSumSquares(double slope) { + return slope * slope * sumXX; + } + /** * Uses distribution framework to get a t distribution instance * with df = n - 2 diff --git a/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java b/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java index b31bda655..5e4481146 100644 --- a/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java +++ b/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java @@ -60,7 +60,7 @@ import junit.framework.TestSuite; * Test cases for the TestStatistic class. * * @author Phil Steitz - * @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $ + * @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $ */ public final class BivariateRegressionTest extends TestCase { @@ -130,15 +130,15 @@ public final class BivariateRegressionTest extends TestCase { assertEquals("r-square",0.999993745883712, regression.getRSquare(),10E-12); assertEquals("SSR",4255954.13232369, - regression.getRegressionSumSquares(),10E-8); + regression.getRegressionSumSquares(),10E-9); assertEquals("MSE",0.782864662630069, - regression.getMeanSquareError(),10E-8); + regression.getMeanSquareError(),10E-10); assertEquals("SSE",26.6173985294224, - regression.getSumSquaredErrors(),10E-8); + regression.getSumSquaredErrors(),10E-9); assertEquals("predict(0)",-0.262323073774029, regression.predict(0),10E-12); assertEquals("predict(1)",1.00211681802045-0.262323073774029, - regression.predict(1),10E-11); + regression.predict(1),10E-12); } public void testCorr() {