reverted some changes introduced yesterday, as they lead to unexpected test failures

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@728500 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-12-21 21:04:47 +00:00
parent 7716a7ac0d
commit c060390793
4 changed files with 39 additions and 35 deletions

View File

@ -16,10 +16,8 @@
*/
package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.RealVectorImpl;
import org.apache.commons.math.linear.RealMatrixImpl;
/**
* Abstract base class for implementations of MultipleLinearRegression.
@ -33,7 +31,7 @@ public abstract class AbstractMultipleLinearRegression implements
protected RealMatrix X;
/** Y sample data. */
protected RealVector Y;
protected RealMatrix Y;
/**
* Loads model x and y sample data from a flat array of data, overriding any previous sample.
@ -54,8 +52,8 @@ public abstract class AbstractMultipleLinearRegression implements
x[i][j] = data[pointer++];
}
}
this.X = MatrixUtils.createRealMatrix(x);
this.Y = new RealVectorImpl(y);
this.X = new RealMatrixImpl(x);
this.Y = new RealMatrixImpl(y);
}
/**
@ -64,7 +62,7 @@ public abstract class AbstractMultipleLinearRegression implements
* @param y the [n,1] array representing the y sample
*/
protected void newYSampleData(double[] y) {
this.Y = new RealVectorImpl(y);
this.Y = new RealMatrixImpl(y);
}
/**
@ -73,7 +71,7 @@ public abstract class AbstractMultipleLinearRegression implements
* @param x the [n,k] array representing the x sample
*/
protected void newXSampleData(double[][] x) {
this.X = MatrixUtils.createRealMatrix(x);
this.X = new RealMatrixImpl(x);
}
/**
@ -122,14 +120,17 @@ public abstract class AbstractMultipleLinearRegression implements
* {@inheritDoc}
*/
public double[] estimateRegressionParameters() {
return calculateBeta().getData();
RealMatrix b = calculateBeta();
return b.getColumn(0);
}
/**
* {@inheritDoc}
*/
public double[] estimateResiduals() {
return Y.subtract(X.operate(calculateBeta())).getData();
RealMatrix b = calculateBeta();
RealMatrix e = Y.subtract(X.multiply(b));
return e.getColumn(0);
}
/**
@ -151,7 +152,7 @@ public abstract class AbstractMultipleLinearRegression implements
*
* @return beta
*/
protected abstract RealVector calculateBeta();
protected abstract RealMatrix calculateBeta();
/**
* Calculates the beta variance of multiple linear regression in matrix
@ -178,8 +179,9 @@ public abstract class AbstractMultipleLinearRegression implements
*
* @return The residuals [n,1] matrix
*/
protected RealVector calculateResiduals() {
return Y.subtract(X.operate(calculateBeta()));
protected RealMatrix calculateResiduals() {
RealMatrix b = calculateBeta();
return Y.subtract(X.multiply(b));
}
}

View File

@ -18,9 +18,8 @@ package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.LUSolver;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.RealMatrixImpl;
/**
@ -69,7 +68,7 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @param omega the [n,n] array representing the covariance
*/
protected void newCovarianceData(double[][] omega){
this.Omega = MatrixUtils.createRealMatrix(omega);
this.Omega = new RealMatrixImpl(omega);
this.OmegaInverse = null;
}
@ -92,12 +91,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* </pre>
* @return beta
*/
protected RealVector calculateBeta() {
protected RealMatrix calculateBeta() {
RealMatrix OI = getOmegaInverse();
RealMatrix XT = X.transpose();
RealMatrix XTOIX = XT.multiply(OI).multiply(X);
RealMatrix inverse = new LUSolver(new LUDecompositionImpl(XTOIX)).getInverse();
return inverse.multiply(XT).multiply(OI).operate(Y);
return inverse.multiply(XT).multiply(OI).multiply(Y);
}
/**
@ -121,9 +120,9 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @return The Y variance
*/
protected double calculateYVariance() {
final RealVector u = calculateResiduals();
final double sse = u.dotProduct(getOmegaInverse().operate(u));
return sse / (X.getRowDimension() - X.getColumnDimension());
RealMatrix u = calculateResiduals();
RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u);
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
}
}

View File

@ -16,14 +16,12 @@
*/
package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.DenseRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.LUSolver;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.RealVectorImpl;
import org.apache.commons.math.linear.RealMatrixImpl;
/**
* <p>Implements ordinary least squares (OLS) to estimate the parameters of a
@ -88,7 +86,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @param x the [n,k] array representing the x sample
*/
protected void newXSampleData(double[][] x) {
this.X = new DenseRealMatrix(x);
this.X = new RealMatrixImpl(x);
qr = new QRDecompositionImpl(X);
}
@ -97,8 +95,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*
* @return beta
*/
protected RealVector calculateBeta() {
return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
protected RealMatrix calculateBeta() {
return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y));
}
/**
@ -122,9 +120,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
* @return The Y variance
*/
protected double calculateYVariance() {
final RealVector u = calculateResiduals();
final double sse = u.dotProduct(u);
return sse / (X.getRowDimension() - X.getColumnDimension());
RealMatrix u = calculateResiduals();
RealMatrix sse = u.transpose().multiply(u);
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
}
/** TODO: Find a home for the following methods in the linear package */
@ -144,14 +142,19 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
*
* @param coefficients upper-triangular coefficients matrix
* @param constants column RHS constants matrix
* @return solution matrix as a vector
* @return solution matrix as a column matrix
*
*/
private static RealVector solveUpperTriangular(RealMatrix coefficients, RealVector constants) {
private static RealMatrix solveUpperTriangular(RealMatrix coefficients,
RealMatrix constants) {
if (!isUpperTriangular(coefficients, 1E-12)) {
throw new IllegalArgumentException(
"Coefficients is not upper-triangular");
}
if (constants.getColumnDimension() != 1) {
throw new IllegalArgumentException(
"Constants not a column matrix.");
}
int length = coefficients.getColumnDimension();
double x[] = new double[length];
for (int i = 0; i < length; i++) {
@ -160,9 +163,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio
for (int j = index + 1; j < length; j++) {
sum += coefficients.getEntry(index, j) * x[j];
}
x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index);
}
return new RealVectorImpl(x);
return new RealMatrixImpl(x);
}
/**

View File

@ -63,7 +63,7 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
createRegression().newSampleData(y, x, null);
}
@Test(expected=ArrayIndexOutOfBoundsException.class)
@Test(expected=IllegalArgumentException.class)
public void cannotAddNullCovarianceData() {
createRegression().newSampleData(new double[]{}, new double[][]{}, null);
}