Mark R. Diggory 2003-06-21 02:13:41 +00:00
parent 7e2a7b5027
commit 81c03fcee2
2 changed files with 58 additions and 47 deletions

View File

@ -85,21 +85,21 @@ import org.apache.commons.math.stat.distribution.TDistribution;
* </ul> * </ul>
* *
* @author Phil Steitz * @author Phil Steitz
* @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $ * @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $
*/ */
public class BivariateRegression { public class BivariateRegression {
/** sum of x values */ /** sum of x values */
private double sumX = 0d; private double sumX = 0d;
/** sum of squared x values */ /** total variation in x (sum of squared deviations from xbar) */
private double sumSqX = 0d; private double sumXX = 0d;
/** sum of y values */ /** sum of y values */
private double sumY = 0d; private double sumY = 0d;
/** sum of squared y values */ /** total variation in y (sum of squared deviations from ybar) */
private double sumSqY = 0d; private double sumYY = 0d;
/** sum of products */ /** sum of products */
private double sumXY = 0d; private double sumXY = 0d;
@ -107,20 +107,41 @@ public class BivariateRegression {
/** number of observations */ /** number of observations */
private long n = 0; private long n = 0;
/** mean of accumulated x values, used in updating formulas */
private double xbar = 0;
/** mean of accumulated y values, used in updating formulas */
private double ybar = 0;
// ---------------------Public methods-------------------------------------- // ---------------------Public methods--------------------------------------
/** /**
* Adds the observation (x,y) to the regression data set * Adds the observation (x,y) to the regression data set.
* <p>
* Uses updating formulas for means and sums of squares defined in
* "Algorithms for Computing the Sample Variance: Analysis and
* Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J.
* 1983, The American Statistician, vol. 37, pp. 242-247, referenced in
* Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985.
*
* *
* @param x independent variable value * @param x independent variable value
* @param y dependent variable value * @param y dependent variable value
*/ */
public void addData(double x, double y) { public void addData(double x, double y) {
if (n == 0) {
xbar = x;
ybar = y;
} else {
sumXX += ((double) n / (double) (n + 1)) * (x - xbar) * (x - xbar);
sumYY += ((double) n / (double) (n + 1)) * (y - ybar) * (y - ybar);
sumXY += ((double) n / (double) (n + 1)) * (x - xbar) * (y - ybar);
xbar += (1d / (double) (n + 1)) * (x - xbar);
ybar += (1d / (double) (n + 1)) * (y - ybar);
}
sumX += x; sumX += x;
sumSqX += x * x;
sumY += y; sumY += y;
sumSqY += y * y;
sumXY += x * y;
n++; n++;
} }
@ -148,9 +169,9 @@ public class BivariateRegression {
*/ */
public void clear() { public void clear() {
sumX = 0d; sumX = 0d;
sumSqX = 0d; sumXX = 0d;
sumY = 0d; sumY = 0d;
sumSqY = 0d; sumYY = 0d;
sumXY = 0d; sumXY = 0d;
n = 0; n = 0;
} }
@ -215,7 +236,7 @@ public class BivariateRegression {
* <strong>Preconditions</strong>: <ul> * <strong>Preconditions</strong>: <ul>
* <li>At least two observations (with at least two different x values) * <li>At least two observations (with at least two different x values)
* must have been added before invoking this method. If this method is * must have been added before invoking this method. If this method is
* invoked before a model can be estimated, <code>Double,NaN</code> is * invoked before a model can be estimated, <code>Double.NaN</code> is
* returned. * returned.
* </li></ul> * </li></ul>
* *
@ -225,12 +246,10 @@ public class BivariateRegression {
if (n < 2) { if (n < 2) {
return Double.NaN; //not enough data return Double.NaN; //not enough data
} }
double dn = (double) n; if (Math.abs(sumXX) < 10 * Double.MIN_VALUE) {
double denom = sumSqX - (sumX * sumX / dn);
if (Math.abs(denom) < 10 * Double.MIN_VALUE) {
return Double.NaN; //not enough variation in x return Double.NaN; //not enough variation in x
} }
return (sumXY - (sumX * sumY / dn)) / denom; return sumXY / sumXX;
} }
/** /**
@ -265,7 +284,7 @@ public class BivariateRegression {
if (n < 2) { if (n < 2) {
return Double.NaN; return Double.NaN;
} }
return sumSqY - sumY * sumY / (double) n; return sumYY;
} }
/** /**
@ -282,11 +301,10 @@ public class BivariateRegression {
* returned. * returned.
* </li></ul> * </li></ul>
* *
* @return sum of squared deviations of y values * @return sum of squared deviations of predicted y values
*/ */
public double getRegressionSumSquares() { public double getRegressionSumSquares() {
double b1 = getSlope(); return getRegressionSumSquares(getSlope());
return b1 * (sumXY - sumX * sumY / (double) n);
} }
/** /**
@ -303,8 +321,7 @@ public class BivariateRegression {
if (n < 3) { if (n < 3) {
return Double.NaN; return Double.NaN;
} }
double sse = getSumSquaredErrors(); return getSumSquaredErrors() / (double) (n - 2);
return sse / (double) (n - 2);
} }
/** /**
@ -361,8 +378,8 @@ public class BivariateRegression {
* @return standard error associated with intercept estimate * @return standard error associated with intercept estimate
*/ */
public double getInterceptStdErr() { public double getInterceptStdErr() {
double ssx = getSumSquaresX(); return Math.sqrt(getMeanSquareError() * ((1d / (double) n) +
return Math.sqrt(getMeanSquareError() * sumSqX / (((double) n) * ssx)); (xbar * xbar) / sumXX));
} }
/** /**
@ -376,8 +393,7 @@ public class BivariateRegression {
* @return standard error associated with slope estimate * @return standard error associated with slope estimate
*/ */
public double getSlopeStdErr() { public double getSlopeStdErr() {
double ssx = getSumSquaresX(); return Math.sqrt(getMeanSquareError() / sumXX);
return Math.sqrt(getMeanSquareError() / ssx);
} }
/** /**
@ -492,24 +508,9 @@ public class BivariateRegression {
* @return sum of squared errors associated with the regression model * @return sum of squared errors associated with the regression model
*/ */
private double getSumSquaredErrors(double b1) { private double getSumSquaredErrors(double b1) {
double b0 = getIntercept(b1); return sumYY - sumXY * sumXY / sumXX;
return sumSqY - b0 * sumY - b1 * sumXY;
} }
/**
* Returns the sum of squared deviations of the x values about their mean.
* <p>
* If n < 2, this returns NaN.
*
* @return sum of squared deviations of x values
*/
private double getSumSquaresX() {
if (n < 2) {
return Double.NaN;
}
return sumSqX - sumX * sumX / (double) n;
}
/** /**
* Computes r-square from the slope. * Computes r-square from the slope.
* <p> * <p>
@ -523,6 +524,16 @@ public class BivariateRegression {
return (ssto - getSumSquaredErrors(b1)) / ssto; return (ssto - getSumSquaredErrors(b1)) / ssto;
} }
/**
* Computes SSR (the regression sum of squares) from the slope estimate.
*
* @param slope regression slope estimate
* @return sum of squared deviations of predicted y values
*/
private double getRegressionSumSquares(double slope) {
return slope * slope * sumXX;
}
/** /**
* Uses distribution framework to get a t distribution instance * Uses distribution framework to get a t distribution instance
* with df = n - 2 * with df = n - 2

View File

@ -60,7 +60,7 @@ import junit.framework.TestSuite;
* Test cases for the BivariateRegression class. * Test cases for the BivariateRegression class.
* *
* @author Phil Steitz * @author Phil Steitz
* @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $ * @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $
*/ */
public final class BivariateRegressionTest extends TestCase { public final class BivariateRegressionTest extends TestCase {
@ -130,15 +130,15 @@ public final class BivariateRegressionTest extends TestCase {
assertEquals("r-square",0.999993745883712, assertEquals("r-square",0.999993745883712,
regression.getRSquare(),10E-12); regression.getRSquare(),10E-12);
assertEquals("SSR",4255954.13232369, assertEquals("SSR",4255954.13232369,
regression.getRegressionSumSquares(),10E-8); regression.getRegressionSumSquares(),10E-9);
assertEquals("MSE",0.782864662630069, assertEquals("MSE",0.782864662630069,
regression.getMeanSquareError(),10E-8); regression.getMeanSquareError(),10E-10);
assertEquals("SSE",26.6173985294224, assertEquals("SSE",26.6173985294224,
regression.getSumSquaredErrors(),10E-8); regression.getSumSquaredErrors(),10E-9);
assertEquals("predict(0)",-0.262323073774029, assertEquals("predict(0)",-0.262323073774029,
regression.predict(0),10E-12); regression.predict(0),10E-12);
assertEquals("predict(1)",1.00211681802045-0.262323073774029, assertEquals("predict(1)",1.00211681802045-0.262323073774029,
regression.predict(1),10E-11); regression.predict(1),10E-12);
} }
public void testCorr() { public void testCorr() {