Refactored data specification in multiple regression api. JIRA: MATH-255. Patched by Mauro Televi.
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@676241 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
b51a782d1b
commit
22d13e1232
|
@ -23,7 +23,7 @@
|
||||||
<name>Math</name>
|
<name>Math</name>
|
||||||
<groupId>commons-math</groupId>
|
<groupId>commons-math</groupId>
|
||||||
<artifactId>commons-math</artifactId>
|
<artifactId>commons-math</artifactId>
|
||||||
<currentVersion>1.2-RC2</currentVersion>
|
<currentVersion>2.0-SNAPSHOT</currentVersion>
|
||||||
<inceptionYear>2003</inceptionYear>
|
<inceptionYear>2003</inceptionYear>
|
||||||
<shortDescription>Commons Math</shortDescription>
|
<shortDescription>Commons Math</shortDescription>
|
||||||
<description>The Math project is a library of lightweight, self-contained mathematics and statistics components addressing the most common practical problems not immediately available in the Java programming language or commons-lang.</description>
|
<description>The Math project is a library of lightweight, self-contained mathematics and statistics components addressing the most common practical problems not immediately available in the Java programming language or commons-lang.</description>
|
||||||
|
|
|
@ -34,20 +34,43 @@ public abstract class AbstractMultipleLinearRegression implements
|
||||||
protected RealMatrix Y;
|
protected RealMatrix Y;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds y sample data.
|
* Loads model x and y sample data from a flat array of data, overriding any previous sample.
|
||||||
|
* Assumes that rows are concatenated with y values first in each row.
|
||||||
|
*
|
||||||
|
* @param data input data array
|
||||||
|
* @param nobs number of observations (rows)
|
||||||
|
* @param nvars number of independent variables (columnns, not counting y)
|
||||||
|
*/
|
||||||
|
public void newSampleData(double[] data, int nobs, int nvars) {
|
||||||
|
double[] y = new double[nobs];
|
||||||
|
double[][] x = new double[nobs][nvars + 1];
|
||||||
|
int pointer = 0;
|
||||||
|
for (int i = 0; i < nobs; i++) {
|
||||||
|
y[i] = data[pointer++];
|
||||||
|
x[i][0] = 1.0d;
|
||||||
|
for (int j = 1; j < nvars + 1; j++) {
|
||||||
|
x[i][j] = data[pointer++];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.X = new RealMatrixImpl(x);
|
||||||
|
this.Y = new RealMatrixImpl(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads new y sample data, overriding any previous sample
|
||||||
*
|
*
|
||||||
* @param y the [n,1] array representing the y sample
|
* @param y the [n,1] array representing the y sample
|
||||||
*/
|
*/
|
||||||
protected void addYSampleData(double[] y) {
|
protected void newYSampleData(double[] y) {
|
||||||
this.Y = new RealMatrixImpl(y);
|
this.Y = new RealMatrixImpl(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds x sample data.
|
* Loads new x sample data, overriding any previous sample
|
||||||
*
|
*
|
||||||
* @param x the [n,k] array representing the x sample
|
* @param x the [n,k] array representing the x sample
|
||||||
*/
|
*/
|
||||||
protected void addXSampleData(double[][] x) {
|
protected void newXSampleData(double[][] x) {
|
||||||
this.X = new RealMatrixImpl(x);
|
this.X = new RealMatrixImpl(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,15 +44,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
/** Covariance matrix. */
|
/** Covariance matrix. */
|
||||||
private RealMatrix Omega;
|
private RealMatrix Omega;
|
||||||
|
|
||||||
/**
|
public void newSampleData(double[] y, double[][] x, double[][] covariance) {
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
public void addData(double[] y, double[][] x, double[][] covariance) {
|
|
||||||
validateSampleData(x, y);
|
validateSampleData(x, y);
|
||||||
addYSampleData(y);
|
newYSampleData(y);
|
||||||
addXSampleData(x);
|
newXSampleData(x);
|
||||||
validateCovarianceData(x, covariance);
|
validateCovarianceData(x, covariance);
|
||||||
addCovarianceData(covariance);
|
newCovarianceData(covariance);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -60,7 +57,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
|
||||||
*
|
*
|
||||||
* @param omega the [n,n] array representing the covariance
|
* @param omega the [n,n] array representing the covariance
|
||||||
*/
|
*/
|
||||||
protected void addCovarianceData(double[][] omega){
|
protected void newCovarianceData(double[][] omega){
|
||||||
this.Omega = new RealMatrixImpl(omega);
|
this.Omega = new RealMatrixImpl(omega);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,17 +32,6 @@ package org.apache.commons.math.stat.regression;
|
||||||
*/
|
*/
|
||||||
public interface MultipleLinearRegression {
|
public interface MultipleLinearRegression {
|
||||||
|
|
||||||
/**
|
|
||||||
* Adds sample and covariance data.
|
|
||||||
*
|
|
||||||
* @param y the [n,1] array representing the y sample
|
|
||||||
* @param x the [n,k] array representing x sample
|
|
||||||
* @param covariance the [n,n] array representing the covariance matrix or <code>null</code> if not required for the
|
|
||||||
* specific implementation
|
|
||||||
* @throws IllegalArgumentException if required data arrays are <code>null</code> or their dimensions are not appropriate
|
|
||||||
*/
|
|
||||||
void addData(double[] y, double[][] x, double[][] covariance);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Estimates the regression parameters b.
|
* Estimates the regression parameters b.
|
||||||
*
|
*
|
||||||
|
|
|
@ -40,13 +40,10 @@ import org.apache.commons.math.linear.RealMatrix;
|
||||||
*/
|
*/
|
||||||
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
|
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
|
||||||
|
|
||||||
/**
|
public void newSampleData(double[] y, double[][] x) {
|
||||||
* {@inheritDoc}
|
|
||||||
*/
|
|
||||||
public void addData(double[] y, double[][] x, double[][] covariance) {
|
|
||||||
validateSampleData(x, y);
|
validateSampleData(x, y);
|
||||||
addYSampleData(y);
|
newYSampleData(y);
|
||||||
addXSampleData(x);
|
newXSampleData(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -61,45 +61,5 @@ public abstract class AbstractMultipleLinearRegressionTest {
|
||||||
double variance = regression.estimateRegressandVariance();
|
double variance = regression.estimateRegressandVariance();
|
||||||
assertTrue(variance > 0.0);
|
assertTrue(variance > 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
|
||||||
public void cannotAddXSampleData() {
|
|
||||||
regression.addData(new double[]{}, null, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
|
||||||
public void cannotAddNullYSampleData() {
|
|
||||||
regression.addData(null, new double[][]{}, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
|
||||||
public void cannotAddSampleDataWithSizeMismatch() {
|
|
||||||
double[] y = new double[]{1.0, 2.0};
|
|
||||||
double[][] x = new double[1][];
|
|
||||||
x[0] = new double[]{1.0, 0};
|
|
||||||
regression.addData(y, x, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loads model Y[] and X[][] arrays from a flat array of data.
|
|
||||||
* Assumes that rows are concatenated with y values first in each row.
|
|
||||||
*
|
|
||||||
* @param data input data array
|
|
||||||
* @param y vector of y values to be filled
|
|
||||||
* @param x matrix of x values to be filled
|
|
||||||
* @param nobs number of observations (rows)
|
|
||||||
* @param nvars number of independent variables (columnns, not counting y)
|
|
||||||
*/
|
|
||||||
protected void loadModelData(double[] data, double[] y, double[][] x, int nobs, int nvars) {
|
|
||||||
int pointer = 0;
|
|
||||||
for (int i = 0; i < nobs; i++) {
|
|
||||||
y[i] = data[pointer++];
|
|
||||||
x[i][0] = 1.0d;
|
|
||||||
for (int j = 1; j < nvars + 1; j++) {
|
|
||||||
x[i][j] = data[pointer++];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,10 +45,27 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
super.setUp();
|
super.setUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
public void cannotAddXSampleData() {
|
||||||
|
createRegression().newSampleData(new double[]{}, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
public void cannotAddNullYSampleData() {
|
||||||
|
createRegression().newSampleData(null, new double[][]{}, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
public void cannotAddSampleDataWithSizeMismatch() {
|
||||||
|
double[] y = new double[]{1.0, 2.0};
|
||||||
|
double[][] x = new double[1][];
|
||||||
|
x[0] = new double[]{1.0, 0};
|
||||||
|
createRegression().newSampleData(y, x, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
@Test(expected=IllegalArgumentException.class)
|
||||||
public void cannotAddNullCovarianceData() {
|
public void cannotAddNullCovarianceData() {
|
||||||
regression.addData(new double[]{}, new double[][]{}, null);
|
createRegression().newSampleData(new double[]{}, new double[][]{}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
@ -59,7 +76,7 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
x[1] = new double[]{0, 1.0};
|
x[1] = new double[]{0, 1.0};
|
||||||
double[][] omega = new double[1][];
|
double[][] omega = new double[1][];
|
||||||
omega[0] = new double[]{1.0, 0};
|
omega[0] = new double[]{1.0, 0};
|
||||||
regression.addData(y, x, omega);
|
createRegression().newSampleData(y, x, omega);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalArgumentException.class)
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
@ -72,12 +89,12 @@ public class GLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
omega[0] = new double[]{1.0, 0};
|
omega[0] = new double[]{1.0, 0};
|
||||||
omega[1] = new double[]{0, 1.0};
|
omega[1] = new double[]{0, 1.0};
|
||||||
omega[2] = new double[]{0, 2.0};
|
omega[2] = new double[]{0, 2.0};
|
||||||
regression.addData(y, x, omega);
|
createRegression().newSampleData(y, x, omega);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected MultipleLinearRegression createRegression() {
|
protected GLSMultipleLinearRegression createRegression() {
|
||||||
MultipleLinearRegression regression = new GLSMultipleLinearRegression();
|
GLSMultipleLinearRegression regression = new GLSMultipleLinearRegression();
|
||||||
regression.addData(y, x, omega);
|
regression.newSampleData(y, x, omega);
|
||||||
return regression;
|
return regression;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,9 +38,9 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
super.setUp();
|
super.setUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected MultipleLinearRegression createRegression() {
|
protected OLSMultipleLinearRegression createRegression() {
|
||||||
MultipleLinearRegression regression = new OLSMultipleLinearRegression();
|
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
|
||||||
regression.addData(y, x, null);
|
regression.newSampleData(y, x);
|
||||||
return regression;
|
return regression;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,6 +52,24 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
return y.length;
|
return y.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
public void cannotAddXSampleData() {
|
||||||
|
createRegression().newSampleData(new double[]{}, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
public void cannotAddNullYSampleData() {
|
||||||
|
createRegression().newSampleData(null, new double[][]{});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected=IllegalArgumentException.class)
|
||||||
|
public void cannotAddSampleDataWithSizeMismatch() {
|
||||||
|
double[] y = new double[]{1.0, 2.0};
|
||||||
|
double[][] x = new double[1][];
|
||||||
|
x[0] = new double[]{1.0, 0};
|
||||||
|
createRegression().newSampleData(y, x);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPerfectFit() {
|
public void testPerfectFit() {
|
||||||
double[] betaHat = regression.estimateRegressionParameters();
|
double[] betaHat = regression.estimateRegressionParameters();
|
||||||
|
@ -102,13 +120,10 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
// Transform to Y and X required by interface
|
// Transform to Y and X required by interface
|
||||||
int nobs = 16;
|
int nobs = 16;
|
||||||
int nvars = 6;
|
int nvars = 6;
|
||||||
double[] y = new double[nobs];
|
|
||||||
double[][] x = new double[nobs][nvars + 1];
|
|
||||||
loadModelData(design, y, x, nobs, nvars);
|
|
||||||
|
|
||||||
// Estimate the model
|
// Estimate the model
|
||||||
MultipleLinearRegression model = new OLSMultipleLinearRegression();
|
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
|
||||||
model.addData(y, x, null);
|
model.newSampleData(design, nobs, nvars);
|
||||||
|
|
||||||
// Check expected beta values from NIST
|
// Check expected beta values from NIST
|
||||||
double[] betaHat = model.estimateRegressionParameters();
|
double[] betaHat = model.estimateRegressionParameters();
|
||||||
|
@ -193,13 +208,10 @@ public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegre
|
||||||
// Transform to Y and X required by interface
|
// Transform to Y and X required by interface
|
||||||
int nobs = 47;
|
int nobs = 47;
|
||||||
int nvars = 4;
|
int nvars = 4;
|
||||||
double[] y = new double[nobs];
|
|
||||||
double[][] x = new double[nobs][nvars + 1];
|
|
||||||
loadModelData(design, y, x, nobs, nvars);
|
|
||||||
|
|
||||||
// Estimate the model
|
// Estimate the model
|
||||||
MultipleLinearRegression model = new OLSMultipleLinearRegression();
|
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
|
||||||
model.addData(y, x, null);
|
model.newSampleData(design, nobs, nvars);
|
||||||
|
|
||||||
// Check expected beta values from R
|
// Check expected beta values from R
|
||||||
double[] betaHat = model.estimateRegressionParameters();
|
double[] betaHat = model.estimateRegressionParameters();
|
||||||
|
|
Loading…
Reference in New Issue