diff --git a/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java b/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java index 3af000366..3fc97f1d8 100644 --- a/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java +++ b/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java @@ -17,7 +17,6 @@ package org.apache.commons.math.stat.regression; import org.apache.commons.math.linear.LUDecompositionImpl; -import org.apache.commons.math.linear.LUSolver; import org.apache.commons.math.linear.QRDecomposition; import org.apache.commons.math.linear.QRDecompositionImpl; import org.apache.commons.math.linear.RealMatrix; @@ -80,6 +79,41 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio qr = new QRDecompositionImpl(X); } + /** + *
Compute the "hat" matrix. + *
+ *The hat matrix is defined in terms of the design matrix X + * by X(X^TX)^-1X^T + *
+ *
The implementation here uses the QR decomposition to compute the + * hat matrix as QIpQ^T where Ip is the p-dimensional identity matrix + * augmented by 0's. This computational formula is from "The Hat Matrix + * in Regression and ANOVA", David C. Hoaglin and Roy E. Welsch, + * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. + * + * @return the hat matrix + */ + public RealMatrix calculateHat() { + // Create augmented identity matrix + RealMatrix Q = qr.getQ(); + final int p = qr.getR().getColumnDimension(); + final int n = Q.getColumnDimension(); + RealMatrixImpl augI = new RealMatrixImpl(n, n); + double[][] augIData = augI.getDataRef(); + for (int i = 0; i < n; i++) { + for (int j =0; j < n; j++) { + if (i == j && i < p) { + augIData[i][j] = 1d; + } else { + augIData[i][j] = 0d; + } + } + } + + // Compute and return Hat matrix + return Q.multiply(augI).multiply(Q.transpose()); + } + /** * Loads new x sample data, overriding any previous sample * diff --git a/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java b/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java index 0550d925b..64664c266 100644 --- a/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java +++ b/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java @@ -17,8 +17,12 @@ package org.apache.commons.math.stat.regression; import org.apache.commons.math.TestUtils; +import org.apache.commons.math.linear.MatrixUtils; +import org.apache.commons.math.linear.RealMatrix; +import org.apache.commons.math.linear.RealMatrixImpl; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertEquals; public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest { @@ -243,4 +247,74 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs -0.4515205619767598,-10.2916870903837587,-15.7812984571900063}, 1E-12); } + + /** + * Test hat matrix computation + * + * @throws Exception + */ + @Test + public void testHat() throws Exception { + + /* + * This example is from "The Hat Matrix in Regression and ANOVA", + * David C. Hoaglin and Roy E. Welsch, + * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. + * + */ + double[] design = new double[] { + 11.14, .499, 11.1, + 12.74, .558, 8.9, + 13.13, .604, 8.8, + 11.51, .441, 8.9, + 12.38, .550, 8.8, + 12.60, .528, 9.9, + 11.13, .418, 10.7, + 11.7, .480, 10.5, + 11.02, .406, 10.5, + 11.41, .467, 10.7 + }; + + int nobs = 10; + int nvars = 2; + + // Estimate the model + OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); + model.newSampleData(design, nobs, nvars); + + RealMatrix hat = model.calculateHat(); + + // Reference data is upper half of symmetric hat matrix + double[] referenceData = new double[] { + .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242, + .242, .292, .136, .243, .128, -.041, .033, -.035, .004, + .417, -.019, .273, .187, -.126, .044, -.153, .004, + .604, .197, -.038, .168, -.022, .275, -.028, + .252, .111, -.030, .019, -.010, -.010, + .148, .042, .117, .012, .111, + .262, .145, .277, .174, + .154, .120, .168, + .315, .148, + .187 + }; + + // Check against reference data and verify symmetry + int k = 0; + for (int i = 0; i < 10; i++) { + for (int j = i; j < 10; j++) { + assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3); + assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12); + k++; + } + } + + /* + * Verify that residuals computed using the hat matrix are close to + * what we get from direct computation, i.e. r = (I - H) y + */ + double[] residuals = model.estimateResiduals(); + RealMatrix I = MatrixUtils.createRealIdentityMatrix(10); + double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0); + TestUtils.assertEquals(residuals, hatResiduals, 10e-12); + } }