Refactored data specification in multiple regression API. JIRA: MATH-255. Patched by Mauro Televi.
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@676241 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
b51a782d1b
commit
22d13e1232
|
@ -23,7 +23,7 @@
|
|||
<name>Math</name>
|
||||
<groupId>commons-math</groupId>
|
||||
<artifactId>commons-math</artifactId>
|
||||
<currentVersion>1.2-RC2</currentVersion>
|
||||
<currentVersion>2.0-SNAPSHOT</currentVersion>
|
||||
<inceptionYear>2003</inceptionYear>
|
||||
<shortDescription>Commons Math</shortDescription>
|
||||
<description>The Math project is a library of lightweight, self-contained mathematics and statistics components addressing the most common practical problems not immediately available in the Java programming language or commons-lang.</description>
|
||||
|
|
|
@ -34,20 +34,43 @@ public abstract class AbstractMultipleLinearRegression implements
|
|||
protected RealMatrix Y;
|
||||
|
||||
/**
|
||||
* Adds y sample data.
|
||||
* Loads model x and y sample data from a flat array of data, overriding any previous sample.
|
||||
* Assumes that rows are concatenated with y values first in each row.
|
||||
*
|
||||
* @param y the [n,1] array representing the y sample
|
||||
* @param data input data array
|
||||
* @param nobs number of observations (rows)
|
||||
* @param nvars number of independent variables (columns, not counting y)
|
||||
*/
|
||||
protected void addYSampleData(double[] y) {
|
||||
public void newSampleData(double[] data, int nobs, int nvars) {
|
||||
double[] y = new double[nobs];
|
||||
double[][] x = new double[nobs][nvars + 1];
|
||||
int pointer = 0;
|
||||
for (int i = 0; i < nobs; i++) {
|
||||
y[i] = data[pointer++];
|
||||
x[i][0] = 1.0d;
|
||||
for (int j = 1; j < nvars + 1; j++) {
|
||||
x[i][j] = data[pointer++];
|
||||
}
|
||||
}
|
||||
this.X = new RealMatrixImpl(x);
|
||||
this.Y = new RealMatrixImpl(y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds x sample data.
|
||||
* Loads new y sample data, overriding any previous sample
|
||||
*
|
||||
* @param y the [n,1] array representing the y sample
|
||||
*/
|
||||
protected void newYSampleData(double[] y) {
|
||||
this.Y = new RealMatrixImpl(y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads new x sample data, overriding any previous sample
|
||||
*
|
||||
* @param x the [n,k] array representing the x sample
|
||||
*/
|
||||
protected void addXSampleData(double[][] x) {
|
||||
protected void newXSampleData(double[][] x) {
|
||||
this.X = new RealMatrixImpl(x);
|
||||
}
|
||||
|
||||
|
|
|
@ -44,15 +44,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
/** Covariance matrix. */
|
||||
private RealMatrix Omega;
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public void addData(double[] y, double[][] x, double[][] covariance) {
|
||||
public void newSampleData(double[] y, double[][] x, double[][] covariance) {
|
||||
validateSampleData(x, y);
|
||||
addYSampleData(y);
|
||||
addXSampleData(x);
|
||||
newYSampleData(y);
|
||||
newXSampleData(x);
|
||||
validateCovarianceData(x, covariance);
|
||||
addCovarianceData(covariance);
|
||||
newCovarianceData(covariance);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -60,7 +57,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
|||
*
|
||||
* @param omega the [n,n] array representing the covariance
|
||||
*/
|
||||
protected void addCovarianceData(double[][] omega){
|
||||
protected void newCovarianceData(double[][] omega){
|
||||
this.Omega = new RealMatrixImpl(omega);
|
||||
}
|
||||
|
||||
|
|
|
@ -32,17 +32,6 @@ package org.apache.commons.math.stat.regression;
|
|||
*/
|
||||
public interface MultipleLinearRegression {
|
||||
|
||||
/**
|
||||
* Adds sample and covariance data.
|
||||
*
|
||||
* @param y the [n,1] array representing the y sample
|
||||
* @param x the [n,k] array representing x sample
|
||||
* @param covariance the [n,n] array representing the covariance matrix or <code>null</code> if not required for the
|
||||
* specific implementation
|
||||
* @throws IllegalArgumentException if required data arrays are <code>null</code> or their dimensions are not appropriate
|
||||
*/
|
||||
void addData(double[] y, double[][] x, double[][] covariance);
|
||||
|
||||
/**
|
||||
* Estimates the regression parameters b.
|
||||
*
|
||||
|
|
|
@ -40,13 +40,10 @@ import org.apache.commons.math.linear.RealMatrix;
|
|||
*/
|
||||
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public void addData(double[] y, double[][] x, double[][] covariance) {
|
||||
public void newSampleData(double[] y, double[][] x) {
|
||||
validateSampleData(x, y);
|
||||
addYSampleData(y);
|
||||
addXSampleData(x);
|
||||
newYSampleData(y);
|
||||
newXSampleData(x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -62,44 +62,4 @@ public abstract class AbstractMultipleLinearRegressionTest {
|
|||
assertTrue(variance > 0.0);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddXSampleData() {
|
||||
regression.addData(new double[]{}, null, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddNullYSampleData() {
|
||||
regression.addData(null, new double[][]{}, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddSampleDataWithSizeMismatch() {
|
||||
double[] y = new double[]{1.0, 2.0};
|
||||
double[][] x = new double[1][];
|
||||
x[0] = new double[]{1.0, 0};
|
||||
regression.addData(y, x, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads model Y[] and X[][] arrays from a flat array of data.
|
||||
* Assumes that rows are concatenated with y values first in each row.
|
||||
*
|
||||
* @param data input data array
|
||||
* @param y vector of y values to be filled
|
||||
* @param x matrix of x values to be filled
|
||||
* @param nobs number of observations (rows)
|
||||
* @param nvars number of independent variables (columns, not counting y)
|
||||
*/
|
||||
protected void loadModelData(double[] data, double[] y, double[][] x, int nobs, int nvars) {
|
||||
int pointer = 0;
|
||||
for (int i = 0; i < nobs; i++) {
|
||||
y[i] = data[pointer++];
|
||||
x[i][0] = 1.0d;
|
||||
for (int j = 1; j < nvars + 1; j++) {
|
||||
x[i][j] = data[pointer++];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -45,10 +45,27 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
super.setUp();
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddXSampleData() {
|
||||
createRegression().newSampleData(new double[]{}, null, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddNullYSampleData() {
|
||||
createRegression().newSampleData(null, new double[][]{}, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddSampleDataWithSizeMismatch() {
|
||||
double[] y = new double[]{1.0, 2.0};
|
||||
double[][] x = new double[1][];
|
||||
x[0] = new double[]{1.0, 0};
|
||||
createRegression().newSampleData(y, x, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddNullCovarianceData() {
|
||||
regression.addData(new double[]{}, new double[][]{}, null);
|
||||
createRegression().newSampleData(new double[]{}, new double[][]{}, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
|
@ -59,7 +76,7 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
x[1] = new double[]{0, 1.0};
|
||||
double[][] omega = new double[1][];
|
||||
omega[0] = new double[]{1.0, 0};
|
||||
regression.addData(y, x, omega);
|
||||
createRegression().newSampleData(y, x, omega);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
|
@ -72,12 +89,12 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
omega[0] = new double[]{1.0, 0};
|
||||
omega[1] = new double[]{0, 1.0};
|
||||
omega[2] = new double[]{0, 2.0};
|
||||
regression.addData(y, x, omega);
|
||||
createRegression().newSampleData(y, x, omega);
|
||||
}
|
||||
|
||||
protected MultipleLinearRegression createRegression() {
|
||||
MultipleLinearRegression regression = new GLSMultipleLinearRegression();
|
||||
regression.addData(y, x, omega);
|
||||
protected GLSMultipleLinearRegression createRegression() {
|
||||
GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression();
|
||||
regression.newSampleData(y, x, omega);
|
||||
return regression;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,9 +38,9 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
super.setUp();
|
||||
}
|
||||
|
||||
protected MultipleLinearRegression createRegression() {
|
||||
MultipleLinearRegression regression = new OLSMultipleLinearRegression();
|
||||
regression.addData(y, x, null);
|
||||
protected OLSMultipleLinearRegression createRegression() {
|
||||
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
|
||||
regression.newSampleData(y, x);
|
||||
return regression;
|
||||
}
|
||||
|
||||
|
@ -52,6 +52,24 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
return y.length;
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddXSampleData() {
|
||||
createRegression().newSampleData(new double[]{}, null);
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddNullYSampleData() {
|
||||
createRegression().newSampleData(null, new double[][]{});
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void cannotAddSampleDataWithSizeMismatch() {
|
||||
double[] y = new double[]{1.0, 2.0};
|
||||
double[][] x = new double[1][];
|
||||
x[0] = new double[]{1.0, 0};
|
||||
createRegression().newSampleData(y, x);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPerfectFit() {
|
||||
double[] betaHat = regression.estimateRegressionParameters();
|
||||
|
@ -102,13 +120,10 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
// Transform to Y and X required by interface
|
||||
int nobs = 16;
|
||||
int nvars = 6;
|
||||
double[] y = new double[nobs];
|
||||
double[][] x = new double[nobs][nvars + 1];
|
||||
loadModelData(design, y, x, nobs, nvars);
|
||||
|
||||
// Estimate the model
|
||||
MultipleLinearRegression model = new OLSMultipleLinearRegression();
|
||||
model.addData(y, x, null);
|
||||
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
|
||||
model.newSampleData(design, nobs, nvars);
|
||||
|
||||
// Check expected beta values from NIST
|
||||
double[] betaHat = model.estimateRegressionParameters();
|
||||
|
@ -193,13 +208,10 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
|||
// Transform to Y and X required by interface
|
||||
int nobs = 47;
|
||||
int nvars = 4;
|
||||
double[] y = new double[nobs];
|
||||
double[][] x = new double[nobs][nvars + 1];
|
||||
loadModelData(design, y, x, nobs, nvars);
|
||||
|
||||
// Estimate the model
|
||||
MultipleLinearRegression model = new OLSMultipleLinearRegression();
|
||||
model.addData(y, x, null);
|
||||
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
|
||||
model.newSampleData(design, nobs, nvars);
|
||||
|
||||
// Check expected beta values from R
|
||||
double[] betaHat = model.estimateRegressionParameters();
|
||||
|
|
Loading…
Reference in New Issue