Added hat matrix computation.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@731166 13f79535-47bb-0310-9956-ffa450edef68
Phil Steitz 2009-01-04 04:11:25 +00:00
parent bede64cd3f
commit 4572684320
2 changed files with 109 additions and 1 deletion

OLSMultipleLinearRegression.java

@@ -17,7 +17,6 @@
package org.apache.commons.math.stat.regression;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.LUSolver;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
@@ -80,6 +79,41 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression
qr = new QRDecompositionImpl(X);
}
/**
* <p>Compute the "hat" matrix.
* </p>
* <p>The hat matrix is defined in terms of the design matrix X
* by X(X^TX)^-1X^T.
* </p>
* <p>The implementation here uses the QR decomposition to compute the
* hat matrix as Q I_p Q^T, where I_p is the p-dimensional identity matrix
* augmented by 0's. This computational formula is from "The Hat Matrix
* in Regression and ANOVA", David C. Hoaglin and Roy E. Welsch,
* The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
* </p>
*
* @return the hat matrix
*/
public RealMatrix calculateHat() {
// Create augmented identity matrix
RealMatrix Q = qr.getQ();
final int p = qr.getR().getColumnDimension();
final int n = Q.getColumnDimension();
RealMatrixImpl augI = new RealMatrixImpl(n, n);
double[][] augIData = augI.getDataRef();
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (i == j && i < p) {
augIData[i][j] = 1d;
} else {
augIData[i][j] = 0d;
}
}
}
// Compute and return Hat matrix
return Q.multiply(augI).multiply(Q.transpose());
}
/**
* Loads new x sample data, overriding any previous sample
*

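As a quick illustration of how the new calculateHat() method might be used by a caller (a hedged sketch, not part of this change): the diagonal entries h_ii of the hat matrix are the leverage values discussed in the Hoaglin and Welsch reference, and for an orthogonal projection their sum equals the number of fitted parameters. The class name HatMatrixExample and the five-observation data set below are made up for illustration; the API calls (newSampleData, calculateHat, getEntry) are the ones exercised in the diff.

// Illustrative caller-side sketch (not part of this commit): uses the new
// calculateHat() to obtain the leverage values h_ii. The class name and the
// five-observation data set are hypothetical.
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;

public class HatMatrixExample {
    public static void main(String[] args) throws Exception {
        // Each row is {y, x1, x2}; values are made up for the sketch.
        double[] design = new double[] {
            10.0, 1.0, 2.0,
            12.0, 2.0, 1.5,
            11.0, 1.5, 3.0,
            13.0, 2.5, 2.5,
            12.5, 2.0, 2.0
        };
        int nobs = 5;
        int nvars = 2;

        OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
        model.newSampleData(design, nobs, nvars);
        RealMatrix hat = model.calculateHat();

        // The diagonal entries of H are the leverages; their sum (the trace)
        // equals the rank of the design matrix, i.e. the number of fitted
        // parameters.
        double trace = 0d;
        for (int i = 0; i < nobs; i++) {
            double leverage = hat.getEntry(i, i);
            trace += leverage;
            System.out.println("h[" + i + "][" + i + "] = " + leverage);
        }
        System.out.println("trace(H) = " + trace);
    }
}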
OLSMultipleLinearRegressionTest.java

@@ -17,8 +17,12 @@
package org.apache.commons.math.stat.regression;
import org.apache.commons.math.TestUtils;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealMatrixImpl;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
@@ -243,4 +247,74 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
-0.4515205619767598,-10.2916870903837587,-15.7812984571900063},
1E-12);
}
/**
* Test hat matrix computation
*
* @throws Exception
*/
@Test
public void testHat() throws Exception {
/*
* This example is from "The Hat Matrix in Regression and ANOVA",
* David C. Hoaglin and Roy E. Welsch,
* The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
*
*/
double[] design = new double[] {
11.14, .499, 11.1,
12.74, .558, 8.9,
13.13, .604, 8.8,
11.51, .441, 8.9,
12.38, .550, 8.8,
12.60, .528, 9.9,
11.13, .418, 10.7,
11.7, .480, 10.5,
11.02, .406, 10.5,
11.41, .467, 10.7
};
int nobs = 10;
int nvars = 2;
// Estimate the model
OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
model.newSampleData(design, nobs, nvars);
RealMatrix hat = model.calculateHat();
// Reference data is upper half of symmetric hat matrix
double[] referenceData = new double[] {
.418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
.242, .292, .136, .243, .128, -.041, .033, -.035, .004,
.417, -.019, .273, .187, -.126, .044, -.153, .004,
.604, .197, -.038, .168, -.022, .275, -.028,
.252, .111, -.030, .019, -.010, -.010,
.148, .042, .117, .012, .111,
.262, .145, .277, .174,
.154, .120, .168,
.315, .148,
.187
};
// Check against reference data and verify symmetry
int k = 0;
for (int i = 0; i < 10; i++) {
for (int j = i; j < 10; j++) {
assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3);
assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12);
k++;
}
}
/*
* Verify that residuals computed using the hat matrix are close to
* what we get from direct computation, i.e. r = (I - H) y
*/
double[] residuals = model.estimateResiduals();
RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0);
TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
}
}
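
A possible companion check for testHat (a sketch only, not in this commit): since the hat matrix is an orthogonal projection, it should be idempotent (H·H = H) and its trace should equal the number of fitted parameters, which the reference diagonal above (summing to roughly 3) suggests is intercept plus two slopes. The method below assumes it is placed inside OLSMultipleLinearRegressionTest, reusing the imports already shown in the diff; the method name is hypothetical.

// Hypothetical companion test (not part of this commit): verifies two
// projection properties of the hat matrix on the same Hoaglin-Welsch data
// used in testHat. Relies on the org.junit.Test, RealMatrix and static
// assertEquals imports already present in the test class.
@Test
public void testHatIsProjection() throws Exception {
    double[] design = new double[] {
        11.14, .499, 11.1,
        12.74, .558, 8.9,
        13.13, .604, 8.8,
        11.51, .441, 8.9,
        12.38, .550, 8.8,
        12.60, .528, 9.9,
        11.13, .418, 10.7,
        11.7, .480, 10.5,
        11.02, .406, 10.5,
        11.41, .467, 10.7
    };
    OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
    model.newSampleData(design, 10, 2);
    RealMatrix hat = model.calculateHat();

    // Idempotence: H * H should reproduce H up to rounding error.
    RealMatrix hatSquared = hat.multiply(hat);
    for (int i = 0; i < 10; i++) {
        for (int j = 0; j < 10; j++) {
            assertEquals(hat.getEntry(i, j), hatSquared.getEntry(i, j), 10e-12);
        }
    }

    // The trace of a projection matrix equals the dimension of its range,
    // here the number of regression parameters. The reference diagonal in
    // testHat sums to roughly 3, i.e. intercept plus two slopes.
    double trace = 0d;
    for (int i = 0; i < 10; i++) {
        trace += hat.getEntry(i, i);
    }
    assertEquals(3d, trace, 10e-12);
}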