From 6c64326a1757e5ab237089116ff4bebc858e7f38 Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Wed, 6 May 2009 09:40:13 +0000 Subject: [PATCH] replaced matrix by vector where possible git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@772114 13f79535-47bb-0310-9956-ffa450edef68 --- .../GLSMultipleLinearRegression.java | 11 ++++---- .../OLSMultipleLinearRegression.java | 28 +++++++++---------- .../OLSMultipleLinearRegressionTest.java | 4 +-- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java b/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java index 15e5e08f9..c76fd431b 100644 --- a/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java +++ b/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java @@ -18,6 +18,7 @@ package org.apache.commons.math.stat.regression; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealMatrixImpl; +import org.apache.commons.math.linear.RealVector; import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; @@ -91,12 +92,12 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * @return beta */ @Override - protected RealMatrix calculateBeta() { + protected RealVector calculateBeta() { RealMatrix OI = getOmegaInverse(); RealMatrix XT = X.transpose(); RealMatrix XTOIX = XT.multiply(OI).multiply(X); RealMatrix inverse = new LUDecompositionImpl(XTOIX).getSolver().getInverse(); - return inverse.multiply(XT).multiply(OI).multiply(Y); + return inverse.multiply(XT).multiply(OI).operate(Y); } /** @@ -122,9 +123,9 @@ public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegressio */ @Override protected double calculateYVariance() { - RealMatrix u = calculateResiduals(); - RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u); - return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension()); + RealVector residuals = calculateResiduals(); + double t = residuals.dotProduct(getOmegaInverse().operate(residuals)); + return t / (X.getRowDimension() - X.getColumnDimension()); } } diff --git a/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java b/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java index 653520fe0..c4da4a3c5 100644 --- a/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java +++ b/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java @@ -18,6 +18,8 @@ package org.apache.commons.math.stat.regression; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealMatrixImpl; +import org.apache.commons.math.linear.RealVector; +import org.apache.commons.math.linear.RealVectorImpl; import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; import org.apache.commons.math.linear.decomposition.QRDecomposition; import org.apache.commons.math.linear.decomposition.QRDecompositionImpl; @@ -137,8 +139,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * @return beta */ @Override - protected RealMatrix calculateBeta() { - return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y)); + protected RealVector calculateBeta() { + return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y)); } /** @@ -170,9 +172,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio */ @Override protected double calculateYVariance() { - RealMatrix u = calculateResiduals(); - RealMatrix sse = u.transpose().multiply(u); - return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension()); + RealVector residuals = calculateResiduals(); + return residuals.dotProduct(residuals) / + (X.getRowDimension() - X.getColumnDimension()); } /** TODO: Find a home for the following methods in the linear package */ @@ -191,20 +193,16 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * Similarly, extra (zero) rows in coefficients are ignored

* * @param coefficients upper-triangular coefficients matrix - * @param constants column RHS constants matrix - * @return solution matrix as a column matrix + * @param constants column RHS constants vector + * @return solution matrix as a column vector * */ - private static RealMatrix solveUpperTriangular(RealMatrix coefficients, - RealMatrix constants) { + private static RealVector solveUpperTriangular(RealMatrix coefficients, + RealVector constants) { if (!isUpperTriangular(coefficients, 1E-12)) { throw new IllegalArgumentException( "Coefficients is not upper-triangular"); } - if (constants.getColumnDimension() != 1) { - throw new IllegalArgumentException( - "Constants not a column matrix."); - } int length = coefficients.getColumnDimension(); double x[] = new double[length]; for (int i = 0; i < length; i++) { @@ -213,9 +211,9 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio for (int j = index + 1; j < length; j++) { sum += coefficients.getEntry(index, j) * x[j]; } - x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index); + x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index); } - return new RealMatrixImpl(x); + return new RealVectorImpl(x); } /** diff --git a/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java b/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java index f7697fdb0..6d68e2cd6 100644 --- a/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java +++ b/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java @@ -139,7 +139,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs new double[]{-3482258.63459582, 15.0618722713733, -0.358191792925910E-01,-2.02022980381683, -1.03322686717359,-0.511041056535807E-01, - 1829.15146461355}, 1E-8); // + 1829.15146461355}, 2E-8); // // Check expected residuals from R double[] residuals = model.estimateResiduals(); @@ -332,7 +332,7 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs */ double[] residuals = model.estimateResiduals(); RealMatrix I = MatrixUtils.createRealIdentityMatrix(10); - double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0); + double[] hatResiduals = I.subtract(hat).operate(model.Y).getData(); TestUtils.assertEquals(residuals, hatResiduals, 10e-12); } }