Refactored data specification in multiple regression api. JIRA: MATH-255. Patched by Mauro Televi.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@676241 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Phil Steitz 2008-07-12 21:41:17 +00:00
parent b51a782d1b
commit 22d13e1232
8 changed files with 84 additions and 89 deletions

View File

@ -23,7 +23,7 @@
<name>Math</name>
<groupId>commons-math</groupId>
<artifactId>commons-math</artifactId>
<currentVersion>1.2-RC2</currentVersion>
<currentVersion>2.0-SNAPSHOT</currentVersion>
<inceptionYear>2003</inceptionYear>
<shortDescription>Commons Math</shortDescription>
<description>The Math project is a library of lightweight, self-contained mathematics and statistics components addressing the most common practical problems not immediately available in the Java programming language or commons-lang.</description>

View File

@ -34,20 +34,43 @@ public abstract class AbstractMultipleLinearRegression implements
protected RealMatrix Y;
/**
* Adds y sample data.
* Loads model x and y sample data from a flat array of data, overriding any previous sample.
* Assumes that rows are concatenated with y values first in each row.
*
* @param y the [n,1] array representing the y sample
* @param data input data array
* @param nobs number of observations (rows)
* @param nvars number of independent variables (columnns, not counting y)
*/
protected void addYSampleData(double[] y) {
public void newSampleData(double[] data, int nobs, int nvars) {
double[] y = new double[nobs];
double[][] x = new double[nobs][nvars + 1];
int pointer = 0;
for (int i = 0; i < nobs; i++) {
y[i] = data[pointer++];
x[i][0] = 1.0d;
for (int j = 1; j < nvars + 1; j++) {
x[i][j] = data[pointer++];
}
}
this.X = new RealMatrixImpl(x);
this.Y = new RealMatrixImpl(y);
}
/**
* Adds x sample data.
* Loads new y sample data, overriding any previous sample
*
* @param y the [n,1] array representing the y sample
*/
protected void newYSampleData(double[] y) {
this.Y = new RealMatrixImpl(y);
}
/**
* Loads new x sample data, overriding any previous sample
*
* @param x the [n,k] array representing the x sample
*/
protected void addXSampleData(double[][] x) {
protected void newXSampleData(double[][] x) {
this.X = new RealMatrixImpl(x);
}

View File

@ -44,15 +44,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
/** Covariance matrix. */
private RealMatrix Omega;
/**
* {@inheritDoc}
*/
public void addData(double[] y, double[][] x, double[][] covariance) {
public void newSampleData(double[] y, double[][] x, double[][] covariance) {
validateSampleData(x, y);
addYSampleData(y);
addXSampleData(x);
newYSampleData(y);
newXSampleData(x);
validateCovarianceData(x, covariance);
addCovarianceData(covariance);
newCovarianceData(covariance);
}
/**
@ -60,7 +57,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*
* @param omega the [n,n] array representing the covariance
*/
protected void addCovarianceData(double[][] omega){
protected void newCovarianceData(double[][] omega){
this.Omega = new RealMatrixImpl(omega);
}

View File

@ -32,17 +32,6 @@ package org.apache.commons.math.stat.regression;
*/
public interface MultipleLinearRegression {
/**
* Adds sample and covariance data.
*
* @param y the [n,1] array representing the y sample
* @param x the [n,k] array representing x sample
* @param covariance the [n,n] array representing the covariance matrix or <code>null</code> if not required for the
* specific implementation
* @throws IllegalArgumentException if required data arrays are <code>null</code> or their dimensions are not appropriate
*/
void addData(double[] y, double[][] x, double[][] covariance);
/**
* Estimates the regression parameters b.
*

View File

@ -40,13 +40,10 @@ import org.apache.commons.math.linear.RealMatrix;
*/
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
/**
* {@inheritDoc}
*/
public void addData(double[] y, double[][] x, double[][] covariance) {
public void newSampleData(double[] y, double[][] x) {
validateSampleData(x, y);
addYSampleData(y);
addXSampleData(x);
newYSampleData(y);
newXSampleData(x);
}
/**

View File

@ -62,44 +62,4 @@ public abstract class AbstractMultipleLinearRegressionTest {
assertTrue(variance > 0.0);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddXSampleData() {
regression.addData(new double[]{}, null, null);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddNullYSampleData() {
regression.addData(null, new double[][]{}, null);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddSampleDataWithSizeMismatch() {
double[] y = new double[]{1.0, 2.0};
double[][] x = new double[1][];
x[0] = new double[]{1.0, 0};
regression.addData(y, x, null);
}
/**
* Loads model Y[] and X[][] arrays from a flat array of data.
* Assumes that rows are concatenated with y values first in each row.
*
* @param data input data array
* @param y vector of y values to be filled
* @param x matrix of x values to be filled
* @param nobs number of observations (rows)
* @param nvars number of independent variables (columnns, not counting y)
*/
protected void loadModelData(double[] data, double[] y, double[][] x, int nobs, int nvars) {
int pointer = 0;
for (int i = 0; i < nobs; i++) {
y[i] = data[pointer++];
x[i][0] = 1.0d;
for (int j = 1; j < nvars + 1; j++) {
x[i][j] = data[pointer++];
}
}
}
}

View File

@ -45,10 +45,27 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
super.setUp();
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddXSampleData() {
createRegression().newSampleData(new double[]{}, null, null);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddNullYSampleData() {
createRegression().newSampleData(null, new double[][]{}, null);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddSampleDataWithSizeMismatch() {
double[] y = new double[]{1.0, 2.0};
double[][] x = new double[1][];
x[0] = new double[]{1.0, 0};
createRegression().newSampleData(y, x, null);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddNullCovarianceData() {
regression.addData(new double[]{}, new double[][]{}, null);
createRegression().newSampleData(new double[]{}, new double[][]{}, null);
}
@Test(expected=IllegalArgumentException.class)
@ -59,7 +76,7 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
x[1] = new double[]{0, 1.0};
double[][] omega = new double[1][];
omega[0] = new double[]{1.0, 0};
regression.addData(y, x, omega);
createRegression().newSampleData(y, x, omega);
}
@Test(expected=IllegalArgumentException.class)
@ -72,12 +89,12 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
omega[0] = new double[]{1.0, 0};
omega[1] = new double[]{0, 1.0};
omega[2] = new double[]{0, 2.0};
regression.addData(y, x, omega);
createRegression().newSampleData(y, x, omega);
}
protected MultipleLinearRegression createRegression() {
MultipleLinearRegression regression = new GLSMultipleLinearRegression();
regression.addData(y, x, omega);
protected GLSMultipleLinearRegression createRegression() {
GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression();
regression.newSampleData(y, x, omega);
return regression;
}

View File

@ -38,9 +38,9 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
super.setUp();
}
protected MultipleLinearRegression createRegression() {
MultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.addData(y, x, null);
protected OLSMultipleLinearRegression createRegression() {
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.newSampleData(y, x);
return regression;
}
@ -52,6 +52,24 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
return y.length;
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddXSampleData() {
createRegression().newSampleData(new double[]{}, null);
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddNullYSampleData() {
createRegression().newSampleData(null, new double[][]{});
}
@Test(expected=IllegalArgumentException.class)
public void cannotAddSampleDataWithSizeMismatch() {
double[] y = new double[]{1.0, 2.0};
double[][] x = new double[1][];
x[0] = new double[]{1.0, 0};
createRegression().newSampleData(y, x);
}
@Test
public void testPerfectFit() {
double[] betaHat = regression.estimateRegressionParameters();
@ -102,13 +120,10 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
// Transform to Y and X required by interface
int nobs = 16;
int nvars = 6;
double[] y = new double[nobs];
double[][] x = new double[nobs][nvars + 1];
loadModelData(design, y, x, nobs, nvars);
// Estimate the model
MultipleLinearRegression model = new OLSMultipleLinearRegression();
model.addData(y, x, null);
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
model.newSampleData(design, nobs, nvars);
// Check expected beta values from NIST
double[] betaHat = model.estimateRegressionParameters();
@ -193,13 +208,10 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
// Transform to Y and X required by interface
int nobs = 47;
int nvars = 4;
double[] y = new double[nobs];
double[][] x = new double[nobs][nvars + 1];
loadModelData(design, y, x, nobs, nvars);
// Estimate the model
MultipleLinearRegression model = new OLSMultipleLinearRegression();
model.addData(y, x, null);
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
model.newSampleData(design, nobs, nvars);
// Check expected beta values from R
double[] betaHat = model.estimateRegressionParameters();