Mark R. Diggory 2003-06-21 02:13:41 +00:00
parent 7e2a7b5027
commit 81c03fcee2
2 changed files with 58 additions and 47 deletions

View File

@ -85,21 +85,21 @@ import org.apache.commons.math.stat.distribution.TDistribution;
* </ul>
*
* @author Phil Steitz
* @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $
* @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $
*/
public class BivariateRegression {
/** sum of x values */
private double sumX = 0d;
/** sum of squared x values */
private double sumSqX = 0d;
/** total variation in x (sum of squared deviations from xbar) */
private double sumXX = 0d;
/** sum of y values */
private double sumY = 0d;
/** sum of squared y values */
private double sumSqY = 0d;
/** total variation in y (sum of squared deviations from ybar) */
private double sumYY = 0d;
/** sum of products */
private double sumXY = 0d;
@ -107,20 +107,41 @@ public class BivariateRegression {
/** number of observations */
private long n = 0;
/** mean of accumulated x values, used in updating formulas */
private double xbar = 0;
/** mean of accumulated y values, used in updating formulas */
private double ybar = 0;
// ---------------------Public methods--------------------------------------
/**
* Adds the observation (x,y) to the regression data set
* Adds the observation (x,y) to the regression data set.
* <p>
* Uses updating formulas for means and sums of squares defined in
* "Algorithms for Computing the Sample Variance: Analysis and
* Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J.
* 1983, American Statistician, vol. 37, pp. 242-247, referenced in
* Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985
*
*
* @param x independent variable value
* @param y dependent variable value
*/
public void addData(double x, double y) {
if (n == 0) {
xbar = x;
ybar = y;
} else {
sumXX += ((double) n / (double) (n + 1)) * (x - xbar) * (x - xbar);
sumYY += ((double) n / (double) (n + 1)) * (y - ybar) * (y - ybar);
sumXY += ((double) n / (double) (n + 1)) * (x - xbar) * (y - ybar);
xbar += (1d / (double) (n + 1)) * (x - xbar);
ybar += (1d / (double) (n + 1)) * (y - ybar);
}
sumX += x;
sumSqX += x * x;
sumY += y;
sumSqY += y * y;
sumXY += x * y;
n++;
}
@ -148,9 +169,9 @@ public class BivariateRegression {
*/
public void clear() {
sumX = 0d;
sumSqX = 0d;
sumXX = 0d;
sumY = 0d;
sumSqY = 0d;
sumYY = 0d;
sumXY = 0d;
n = 0;
}
@ -215,7 +236,7 @@ public class BivariateRegression {
* <strong>Preconditions</strong>: <ul>
* <li>At least two observations (with at least two different x values)
* must have been added before invoking this method. If this method is
* invoked before a model can be estimated, <code>Double,NaN</code> is
* invoked before a model can be estimated, <code>Double.NaN</code> is
* returned.
* </li></ul>
*
@ -225,12 +246,10 @@ public class BivariateRegression {
if (n < 2) {
return Double.NaN; //not enough data
}
double dn = (double) n;
double denom = sumSqX - (sumX * sumX / dn);
if (Math.abs(denom) < 10 * Double.MIN_VALUE) {
if (Math.abs(sumXX) < 10 * Double.MIN_VALUE) {
return Double.NaN; //not enough variation in x
}
return (sumXY - (sumX * sumY / dn)) / denom;
return sumXY / sumXX;
}
/**
@ -265,7 +284,7 @@ public class BivariateRegression {
if (n < 2) {
return Double.NaN;
}
return sumSqY - sumY * sumY / (double) n;
return sumYY;
}
/**
@ -282,11 +301,10 @@ public class BivariateRegression {
* returned.
* </li></ul>
*
* @return sum of squared deviations of y values
* @return sum of squared deviations of predicted y values
*/
public double getRegressionSumSquares() {
double b1 = getSlope();
return b1 * (sumXY - sumX * sumY / (double) n);
return getRegressionSumSquares(getSlope());
}
/**
@ -303,8 +321,7 @@ public class BivariateRegression {
if (n < 3) {
return Double.NaN;
}
double sse = getSumSquaredErrors();
return sse / (double) (n - 2);
return getSumSquaredErrors() / (double) (n - 2);
}
/**
@ -361,8 +378,8 @@ public class BivariateRegression {
* @return standard error associated with intercept estimate
*/
public double getInterceptStdErr() {
double ssx = getSumSquaresX();
return Math.sqrt(getMeanSquareError() * sumSqX / (((double) n) * ssx));
return Math.sqrt(getMeanSquareError() * ((1d / (double) n) +
(xbar * xbar) / sumXX));
}
/**
@ -376,8 +393,7 @@ public class BivariateRegression {
* @return standard error associated with slope estimate
*/
public double getSlopeStdErr() {
double ssx = getSumSquaresX();
return Math.sqrt(getMeanSquareError() / ssx);
return Math.sqrt(getMeanSquareError() / sumXX);
}
/**
@ -492,22 +508,7 @@ public class BivariateRegression {
* @return sum of squared errors associated with the regression model
*/
private double getSumSquaredErrors(double b1) {
double b0 = getIntercept(b1);
return sumSqY - b0 * sumY - b1 * sumXY;
}
/**
* Returns the sum of squared deviations of the x values about their mean.
* <p>
* If n < 2, this returns NaN.
*
* @return sum of squared deviations of x values
*/
private double getSumSquaresX() {
if (n < 2) {
return Double.NaN;
}
return sumSqX - sumX * sumX / (double) n;
return sumYY - sumXY * sumXY / sumXX;
}
/**
@ -523,6 +524,16 @@ public class BivariateRegression {
return (ssto - getSumSquaredErrors(b1)) / ssto;
}
/**
* Computes SSR from b1.
*
* @param slope regression slope estimate
* @return sum of squared deviations of predicted y values
*/
private double getRegressionSumSquares(double slope) {
return slope * slope * sumXX;
}
/**
* Uses distribution framework to get a t distribution instance
* with df = n - 2

View File

@ -60,7 +60,7 @@ import junit.framework.TestSuite;
* Test cases for the TestStatistic class.
*
* @author Phil Steitz
* @version $Revision: 1.2 $ $Date: 2003/06/11 03:33:05 $
* @version $Revision: 1.3 $ $Date: 2003/06/21 02:13:41 $
*/
public final class BivariateRegressionTest extends TestCase {
@ -130,15 +130,15 @@ public final class BivariateRegressionTest extends TestCase {
assertEquals("r-square",0.999993745883712,
regression.getRSquare(),10E-12);
assertEquals("SSR",4255954.13232369,
regression.getRegressionSumSquares(),10E-8);
regression.getRegressionSumSquares(),10E-9);
assertEquals("MSE",0.782864662630069,
regression.getMeanSquareError(),10E-8);
regression.getMeanSquareError(),10E-10);
assertEquals("SSE",26.6173985294224,
regression.getSumSquaredErrors(),10E-8);
regression.getSumSquaredErrors(),10E-9);
assertEquals("predict(0)",-0.262323073774029,
regression.predict(0),10E-12);
assertEquals("predict(1)",1.00211681802045-0.262323073774029,
regression.predict(1),10E-11);
regression.predict(1),10E-12);
}
public void testCorr() {