JIRA: MATH-607 Adding support for UpdatingMultipleLinearRegression to SimpleRegression

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1174509 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Greg Sterijevski 2011-09-23 03:36:11 +00:00
parent 0ba7ddcd9e
commit 81821bc466
6 changed files with 266 additions and 26 deletions

View File

@ -42,6 +42,7 @@ public enum LocalizedFormats implements Localizable {
// CHECKSTYLE: stop JavadocVariable
ARGUMENT_OUTSIDE_DOMAIN("Argument {0} outside domain [{1} ; {2}]"),
ARRAY_SIZE_EXCEEDS_MAX_VARIABLES("array size cannot be greater than {0}"),
ARRAY_SIZES_SHOULD_HAVE_DIFFERENCE_1("array sizes should have difference 1 ({0} != {1} + 1)"),
ARRAY_SUMS_TO_ZERO("array sums to zero"),
ASSYMETRIC_EIGEN_NOT_SUPPORTED("eigen decomposition of assymetric matrices not supported yet"),
@ -135,6 +136,7 @@ public enum LocalizedFormats implements Localizable {
INVALID_INTERVAL_INITIAL_VALUE_PARAMETERS("invalid interval, initial value parameters: lower={0}, initial={1}, upper={2}"),
INVALID_ITERATIONS_LIMITS("invalid iteration limits: min={0}, max={1}"),
INVALID_MAX_ITERATIONS("bad value for maximum iterations number: {0}"),
NOT_ENOUGH_DATA_REGRESSION("the number of observations is not sufficient to conduct regression"),
INVALID_REGRESSION_ARRAY("input data array length = {0} does not match the number of observations = {1} and the number of regressors = {2}"),
INVALID_REGRESSION_OBSERVATION("length of regressor array = {0} does not match the number of variables = {1} in the model"),
INVALID_ROUNDING_METHOD("invalid rounding method {0}, valid methods: {1} ({2}), {3} ({4}), {5} ({6}), {7} ({8}), {9} ({10}), {11} ({12}), {13} ({14}), {15} ({16})"),
@ -239,6 +241,7 @@ public enum LocalizedFormats implements Localizable {
NO_RESULT_AVAILABLE("no result available"),
NO_SUCH_MATRIX_ENTRY("no entry at indices ({0}, {1}) in a {2}x{3} matrix"),
NULL_NOT_ALLOWED("null is not allowed"), /* keep */
ARRAY_ZERO_LENGTH_OR_NULL_NOTALLOWED("A null or zero length array not allowed"),
COVARIANCE_MATRIX("covariance matrix"), /* keep */
DENOMINATOR("denominator"), /* keep */
DENOMINATOR_FORMAT("denominator format"), /* keep */

View File

@ -114,7 +114,7 @@ public class RegressionResults implements Serializable {
this.globalFitInfo = new double[5];
Arrays.fill(this.globalFitInfo, Double.NaN);
if (rank > 2) {
if (rank > 0) {
this.globalFitInfo[SST_IDX] = containsConstant ?
(sumysq - sumy * sumy / ((double) nobs)) : sumysq;
}

View File

@ -22,8 +22,11 @@ import org.apache.commons.math.MathException;
import org.apache.commons.math.exception.OutOfRangeException;
import org.apache.commons.math.distribution.TDistribution;
import org.apache.commons.math.distribution.TDistributionImpl;
import org.apache.commons.math.exception.MathIllegalArgumentException;
import org.apache.commons.math.exception.NoDataException;
import org.apache.commons.math.exception.util.LocalizedFormats;
import org.apache.commons.math.util.FastMath;
import org.apache.commons.math.util.MathUtils;
/**
* Estimates an ordinary least squares regression model
@ -55,7 +58,7 @@ import org.apache.commons.math.util.FastMath;
*
* @version $Id$
*/
public class SimpleRegression implements Serializable {
public class SimpleRegression implements Serializable, UpdatingMultipleLinearRegression {
/** Serializable version identifier */
private static final long serialVersionUID = -3004689053607543335L;
@ -98,7 +101,7 @@ public class SimpleRegression implements Serializable {
* Secondary constructor which allows the user the ability to include/exclude const
* @param includeIntercept boolean flag, true includes an intercept
*/
public SimpleRegression(boolean includeIntercept){
public SimpleRegression(boolean includeIntercept) {
super();
hasIntercept = includeIntercept;
}
@ -116,7 +119,7 @@ public class SimpleRegression implements Serializable {
* @param x independent variable value
* @param y dependent variable value
*/
public void addData(final double x, final double y){
public void addData(final double x,final double y) {
if (n == 0) {
xbar = x;
ybar = y;
@ -158,7 +161,7 @@ public class SimpleRegression implements Serializable {
* @param x independent variable value
* @param y dependent variable value
*/
public void removeData(double x, double y) {
public void removeData(final double x,final double y) {
if (n > 0) {
if (hasIntercept) {
final double fact1 = (double) n - 1.0;
@ -200,13 +203,69 @@ public class SimpleRegression implements Serializable {
* data.</p>
*
* @param data array of observations to be added
* @throws ModelSpecificationException if the length of {@code data[i]} is not
* greater than or equal to 2
*/
public void addData(double[][] data) {
public void addData(final double[][] data) {
for (int i = 0; i < data.length; i++) {
if( data[i].length < 2 ){
throw new ModelSpecificationException(LocalizedFormats.INVALID_REGRESSION_OBSERVATION,
data[i].length, 2);
}
addData(data[i][0], data[i][1]);
}
return;
}
/**
* Adds one observation to the regression model.
*
* @param x the independent variables which form the design matrix
* @param y the dependent or response variable
* @throws ModelSpecificationException if the length of {@code x} does not equal
* the number of independent variables in the model
*/
public void addObservation(final double[] x,final double y) throws ModelSpecificationException{
if( x == null || x.length == 0 ){
throw new ModelSpecificationException(LocalizedFormats.INVALID_REGRESSION_OBSERVATION,x!=null?x.length:0, 1);
}
addData( x[0], y );
return;
}
/**
* Adds a series of observations to the regression model. The lengths of
* x and y must be the same and x must be rectangular.
*
* @param x a series of observations on the independent variables
* @param y a series of observations on the dependent variable
* The length of x and y must be the same
* @throws ModelSpecificationException if {@code x} is not rectangular, does not match
* the length of {@code y} or does not contain sufficient data to estimate the model
*/
public void addObservations(final double[][] x,final double[] y) {
if ((x == null) || (y == null) || (x.length != y.length)) {
throw new ModelSpecificationException(
LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE,
(x == null) ? 0 : x.length,
(y == null) ? 0 : y.length);
}
boolean obsOk=true;
for( int i = 0 ; i < x.length; i++){
if( x[i] == null || x[i].length == 0 ){
obsOk = false;
}
}
if( !obsOk ){
throw new ModelSpecificationException(
LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
0, 1);
}
for( int i = 0 ; i < x.length ; i++){
addData( x[i][0], y[i] );
}
return;
}
/**
* Removes observations represented by the elements in <code>data</code>.
@ -265,8 +324,8 @@ public class SimpleRegression implements Serializable {
* @param x input <code>x</code> value
* @return predicted <code>y</code> value
*/
public double predict(double x) {
double b1 = getSlope();
public double predict(final double x) {
final double b1 = getSlope();
if (hasIntercept) {
return getIntercept(b1) + b1 * x;
}
@ -298,7 +357,7 @@ public class SimpleRegression implements Serializable {
*
* @return true if constant exists, false otherwise
*/
public boolean hasIntercept(){
public boolean hasIntercept() {
return hasIntercept;
}
@ -572,7 +631,7 @@ public class SimpleRegression implements Serializable {
* @return half-width of 95% confidence interval for the slope estimate
* @throws MathException if the confidence interval can not be computed.
*/
public double getSlopeConfidenceInterval(double alpha)
public double getSlopeConfidenceInterval(final double alpha)
throws MathException {
if (alpha >= 1 || alpha <= 0) {
throw new OutOfRangeException(LocalizedFormats.SIGNIFICANCE_LEVEL,
@ -620,7 +679,7 @@ public class SimpleRegression implements Serializable {
* @param slope current slope
* @return the intercept of the regression line
*/
private double getIntercept(double slope){
private double getIntercept(final double slope) {
if( hasIntercept){
return (sumY - slope * sumX) / n;
}
@ -633,7 +692,134 @@ public class SimpleRegression implements Serializable {
* @param slope regression slope estimate
* @return sum of squared deviations of predicted y values
*/
private double getRegressionSumSquares(double slope) {
private double getRegressionSumSquares(final double slope) {
return slope * slope * sumXX;
}
/**
* Performs a regression on data present in buffers and outputs a RegressionResults object
* @return RegressionResults acts as a container of regression output
* @throws ModelSpecificationException if the model is not correctly specified
*/
public RegressionResults regress() throws ModelSpecificationException{
if( hasIntercept ){
if( n < 3 ){
throw new NoDataException( LocalizedFormats.NOT_ENOUGH_DATA_REGRESSION );
}
if( FastMath.abs( sumXX ) > MathUtils.SAFE_MIN ){
final double[] params = new double[]{ getIntercept(), getSlope() };
final double mse = getMeanSquareError();
final double _syy = sumYY + sumY * sumY / ((double) n);
final double[] vcv = new double[]{
mse * (xbar *xbar /sumXX + 1.0 / ((double) n)),
-xbar*mse/sumXX,
mse/sumXX };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 2,
sumY, _syy, getSumSquaredErrors(),true,false);
}else{
final double[] params = new double[]{ sumY/((double) n), Double.NaN };
//final double mse = getMeanSquareError();
final double[] vcv = new double[]{
ybar / ((double) n - 1.0),
Double.NaN,
Double.NaN };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 1,
sumY, sumYY, getSumSquaredErrors(),true,false);
}
}else{
if( n < 2 ){
throw new NoDataException( LocalizedFormats.NOT_ENOUGH_DATA_REGRESSION );
}
if( !Double.isNaN(sumXX) ){
final double[] vcv = new double[]{ getMeanSquareError() / sumXX };
final double[] params = new double[]{ sumXY/sumXX };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 1,
sumY, sumYY, getSumSquaredErrors(),false,false);
}else{
final double[] vcv = new double[]{Double.NaN };
final double[] params = new double[]{ Double.NaN };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 1,
Double.NaN, Double.NaN, Double.NaN,false,false);
}
}
}
/**
* Performs a regression on data present in buffers including only regressors
* indexed in variablesToInclude and outputs a RegressionResults object
* @param variablesToInclude an array of indices of regressors to include
* @return RegressionResults acts as a container of regression output
* @throws ModelSpecificationException if the model is not correctly specified
* @throws MathIllegalArgumentException if the variablesToInclude array is null or zero length
* @throws OutOfRangeException if a requested variable is not present in model
*/
public RegressionResults regress(int[] variablesToInclude) throws ModelSpecificationException{
if( variablesToInclude == null || variablesToInclude.length == 0){
throw new MathIllegalArgumentException(LocalizedFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOTALLOWED);
}
if( variablesToInclude.length > 2 || (variablesToInclude.length > 1 && !hasIntercept) ){
throw new ModelSpecificationException(
LocalizedFormats.ARRAY_SIZE_EXCEEDS_MAX_VARIABLES,
(variablesToInclude.length > 1 && !hasIntercept) ? 1 : 2);
}
if( hasIntercept ){
if( variablesToInclude.length == 2 ){
if( variablesToInclude[0] == 1 ){
throw new ModelSpecificationException(LocalizedFormats.NOT_INCREASING_SEQUENCE);
}else if( variablesToInclude[0] != 0 ){
throw new OutOfRangeException( variablesToInclude[0], 0,1 );
}
if( variablesToInclude[1] != 1){
throw new OutOfRangeException( variablesToInclude[0], 0,1 );
}
return regress();
}else{
if( variablesToInclude[0] != 1 && variablesToInclude[0] != 0 ){
throw new OutOfRangeException( variablesToInclude[0],0,1 );
}
final double _mean = sumY * sumY / ((double) n);
final double _syy = sumYY + _mean;
if( variablesToInclude[0] == 0 ){
//just the mean
final double[] vcv = new double[]{ sumYY/((double)((n-1)*n)) };
final double[] params = new double[]{ ybar };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 1,
sumY, _syy+_mean, sumYY,true,false);
}else if( variablesToInclude[0] == 1){
//final double _syy = sumYY + sumY * sumY / ((double) n);
final double _sxx = sumXX + sumX * sumX / ((double) n);
final double _sxy = sumXY + sumX * sumY / ((double) n);
final double _sse = FastMath.max(0d, _syy - _sxy * _sxy / _sxx);
final double _mse = _sse/((double)(n-1));
if( !Double.isNaN(_sxx) ){
final double[] vcv = new double[]{ _mse / _sxx };
final double[] params = new double[]{ _sxy/_sxx };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 1,
sumY, _syy, _sse,false,false);
}else{
final double[] vcv = new double[]{Double.NaN };
final double[] params = new double[]{ Double.NaN };
return new RegressionResults(
params, new double[][]{vcv}, true, n, 1,
Double.NaN, Double.NaN, Double.NaN,false,false);
}
}
}
}else{
if( variablesToInclude[0] != 0 ){
throw new OutOfRangeException(variablesToInclude[0],0,0);
}
return regress();
}
return null;
}
}

View File

@ -61,7 +61,7 @@ public interface UpdatingMultipleLinearRegression {
* @throws ModelSpecificationException if {@code x} is not rectangular, does not match
* the length of {@code y} or does not contain sufficient data to estimate the model
*/
void addObservations(double[][] x, double[] y);
void addObservations(double[][] x, double[] y) throws ModelSpecificationException;
/**
* Clears internal buffers and resets the regression model. This means all

View File

@ -52,6 +52,9 @@ The <action> type attribute can be add,update,fix,remove.
If the output is not quite correct, check for invisible trailing spaces!
-->
<release version="3.0" date="TBD" description="TBD">
<action dev="gregs" type="update" issue="MATH-607">
SimpleRegression implements UpdatingMultipleLinearRegression interface.
</action>
<action dev="gregs" type="update" issue="MATH-675">
Added isMonotone methods in MathUtils. Optimized checkOrder method.
</action>

View File

@ -80,9 +80,9 @@ public final class SimpleRegressionTest {
{5, -1 }, {6, 12 }
};
/*
* Data from NIST NOINT1
* Data from NIST NOINT1
*/
private double[][] noint1 = {
{130.0,60.0},
@ -95,26 +95,74 @@ public final class SimpleRegressionTest {
{137.0,67.0},
{138.0,68.0},
{139.0,69.0},
{140.0,70.0}
{140.0,70.0}
};
/*
* Data from NIST NOINT2
*
* Data from NIST NOINT2
*
*/
private double[][] noint2 = {
{3.0,4},
{4,5},
{4,6}
};
@Test
public void testRegressIfaceMethod(){
final SimpleRegression regression = new SimpleRegression(true);
final UpdatingMultipleLinearRegression iface = regression;
final SimpleRegression regressionNoint = new SimpleRegression( false );
final SimpleRegression regressionIntOnly= new SimpleRegression( false );
for (int i = 0; i < data.length; i++) {
iface.addObservation( new double[]{data[i][1]}, data[i][0]);
regressionNoint.addData(data[i][1], data[i][0]);
regressionIntOnly.addData(1.0, data[i][0]);
}
//should not be null
final RegressionResults fullReg = iface.regress( );
Assert.assertTrue(fullReg != null);
Assert.assertEquals("intercept", regression.getIntercept(), fullReg.getParameterEstimate(0), 1.0e-16);
Assert.assertEquals("intercept std err",regression.getInterceptStdErr(), fullReg.getStdErrorOfEstimate(0),1.0E-16);
Assert.assertEquals("slope", regression.getSlope(), fullReg.getParameterEstimate(1), 1.0e-16);
Assert.assertEquals("slope std err",regression.getSlopeStdErr(), fullReg.getStdErrorOfEstimate(1),1.0E-16);
Assert.assertEquals("number of observations",regression.getN(), fullReg.getN());
Assert.assertEquals("r-square",regression.getRSquare(), fullReg.getRSquared(), 1.0E-16);
Assert.assertEquals("SSR", regression.getRegressionSumSquares(), fullReg.getRegressionSumSquares() ,1.0E-16);
Assert.assertEquals("MSE", regression.getMeanSquareError(), fullReg.getMeanSquareError() ,1.0E-16);
Assert.assertEquals("SSE", regression.getSumSquaredErrors(), fullReg.getErrorSumSquares() ,1.0E-16);
final RegressionResults noInt = iface.regress( new int[]{1} );
Assert.assertTrue(noInt != null);
Assert.assertEquals("slope", regressionNoint.getSlope(), noInt.getParameterEstimate(0), 1.0e-12);
Assert.assertEquals("slope std err",regressionNoint.getSlopeStdErr(), noInt.getStdErrorOfEstimate(0),1.0E-16);
Assert.assertEquals("number of observations",regressionNoint.getN(), noInt.getN());
Assert.assertEquals("r-square",regressionNoint.getRSquare(), noInt.getRSquared(), 1.0E-16);
Assert.assertEquals("SSR", regressionNoint.getRegressionSumSquares(), noInt.getRegressionSumSquares() ,1.0E-8);
Assert.assertEquals("MSE", regressionNoint.getMeanSquareError(), noInt.getMeanSquareError() ,1.0E-16);
Assert.assertEquals("SSE", regressionNoint.getSumSquaredErrors(), noInt.getErrorSumSquares() ,1.0E-16);
final RegressionResults onlyInt = iface.regress( new int[]{0} );
Assert.assertTrue( onlyInt != null );
Assert.assertEquals("slope", regressionIntOnly.getSlope(), onlyInt.getParameterEstimate(0), 1.0e-12);
Assert.assertEquals("slope std err",regressionIntOnly.getSlopeStdErr(), onlyInt.getStdErrorOfEstimate(0),1.0E-12);
Assert.assertEquals("number of observations",regressionIntOnly.getN(), onlyInt.getN());
Assert.assertEquals("r-square",regressionIntOnly.getRSquare(), onlyInt.getRSquared(), 1.0E-14);
Assert.assertEquals("SSE", regressionIntOnly.getSumSquaredErrors(), onlyInt.getErrorSumSquares() ,1.0E-8);
Assert.assertEquals("SSR", regressionIntOnly.getRegressionSumSquares(), onlyInt.getRegressionSumSquares() ,1.0E-8);
Assert.assertEquals("MSE", regressionIntOnly.getMeanSquareError(), onlyInt.getMeanSquareError() ,1.0E-8);
}
@Test
public void testNoInterceot_noint2(){
SimpleRegression regression = new SimpleRegression(false);
regression.addData(noint2[0][1], noint2[0][0]);
regression.addData(noint2[1][1], noint2[1][0]);
regression.addData(noint2[2][1], noint2[2][0]);
Assert.assertEquals("slope", 0.727272727272727,
Assert.assertEquals("slope", 0.727272727272727,
regression.getSlope(), 10E-12);
Assert.assertEquals("slope std err", 0.420827318078432E-01,
regression.getSlopeStdErr(),10E-12);
@ -128,8 +176,8 @@ public final class SimpleRegressionTest {
Assert.assertEquals("SSE", 0.272727272727273,
regression.getSumSquaredErrors(),10E-9);
}
@Test
@Test
public void testNoIntercept_noint1(){
SimpleRegression regression = new SimpleRegression(false);
for (int i = 0; i < noint1.length; i++) {
@ -147,9 +195,9 @@ public final class SimpleRegressionTest {
regression.getMeanSquareError(), 10E-10);
Assert.assertEquals("SSE", 127.272727272727,
regression.getSumSquaredErrors(),10E-9);
}
}
@Test
public void testNorris() {
SimpleRegression regression = new SimpleRegression();