Started work on JAMA-like interfaces.

This first step is an enhancement of the existing QR-decomposition interface and associated implementation in JAMA-style, i.e with added getH() method and most importantly various solve methods for least-squares solution of the A * X = B equation.
JIRA: MATH-220

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@687167 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-08-19 21:41:17 +00:00
parent 07e312f0c3
commit 3dd6fe1807
6 changed files with 522 additions and 64 deletions

View File

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import java.io.Serializable;
/**
* A base interface to decomposition algorithms that can solve A × X = B.
* <p>This interface is the common base of decomposition algorithms like
* {@link QRDecomposition} or {@link LUDecomposition}. All these algorithms
* decompose an A matrix has a product of several specific matrices from
* which they can solve A &times; X = B.</p>
* <p>Depending on the solver, the solution is either an exact linear solution
* or a least squares solution. When an exact linear solution exist, both the
* linear and the least squares solution are equal. When no exact linear solution
* exist, a least square solution gives an X which such that A &times; X is the
* closest possible to B.</p>
*
* @version $Revision$ $Date$
* @since 2.0
*/
public interface DecompositionSolver extends Serializable {
/** Solve the linear equation A &times; X = B.
* <p>The A matrix is implicit here. It is </p>
* @param b right-hand side of the equation A &times; X = B
* @return a vector X that minimizes the two norm of A &times; X - B
* @throws IllegalArgumentException if matrices dimensions don't match
* @throws InvalidMatrixException if decomposed matrix is singular
*/
double[] solve(double[] b)
throws IllegalArgumentException, InvalidMatrixException;
/** Solve the linear equation A &times; X = B.
* <p>The A matrix is implicit here. It is </p>
* @param b right-hand side of the equation A &times; X = B
* @return a vector X that minimizes the two norm of A &times; X - B
* @throws IllegalArgumentException if matrices dimensions don't match
* @throws InvalidMatrixException if decomposed matrix is singular
*/
RealVector solve(RealVector b)
throws IllegalArgumentException, InvalidMatrixException;
/** Solve the linear equation A &times; X = B.
* <p>The A matrix is implicit here. It is </p>
* @param b right-hand side of the equation A &times; X = B
* @return a matrix X that minimizes the two norm of A &times; X - B
* @throws IllegalArgumentException if matrices dimensions don't match
* @throws InvalidMatrixException if decomposed matrix is singular
*/
RealMatrix solve(RealMatrix b)
throws IllegalArgumentException, InvalidMatrixException;
}

View File

@ -20,25 +20,43 @@ package org.apache.commons.math.linear;
/**
* An interface to classes that implement a algorithm to calculate the
* QR-decomposition of a real matrix.
* <p>This interface is similar to the class with similar name from the now defunct
* <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
*
* @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
* @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
* @version $Revision$ $Date$
* @since 1.2
*/
public interface QRDecomposition {
public interface QRDecomposition extends DecompositionSolver {
/**
* Returns the matrix R of the decomposition.
*
* <p>R is an upper-triangular matrix</p>
* @return the R matrix
*/
public abstract RealMatrix getR();
RealMatrix getR();
/**
* Returns the matrix Q of the decomposition.
*
* <p>Q is an orthogonal matrix</p>
* @return the Q matrix
*/
public abstract RealMatrix getQ();
RealMatrix getQ();
/**
* Returns the Householder reflector vectors.
* <p>H is a lower trapezoidal matrix whose columns represent
* each successive Householder reflector vector. This matrix is used
* to compute Q.</p>
* @return a matrix containing the Householder reflector vectors
*/
RealMatrix getH();
/**
* Check if the decomposed matrix is full rank.
* @return true if the decomposed matrix is full rank
*/
boolean isFullRank();
}

View File

@ -33,29 +33,41 @@ package org.apache.commons.math.linear;
*/
public class QRDecompositionImpl implements QRDecomposition {
/** Serializable version identifier. */
private static final long serialVersionUID = 3965943878043764074L;
/**
* A packed representation of the QR decomposition. The elements above the
* diagonal are the elements of R, and the columns of the lower triangle
* are the Householder reflector vectors of which an explicit form of Q can
* be calculated.
*/
private double[][] qr;
private final double[][] qr;
/**
* The diagonal elements of R.
*/
private double[] rDiag;
private final double[] rDiag;
/** Cached value of Q. */
private RealMatrix cachedQ;
/** Cached value of R. */
private RealMatrix cachedR;
/** Cached value of H. */
private RealMatrix cachedH;
/**
* The row dimension of the given matrix. The size of Q will be m x m, the
* size of R will be m x n.
*/
private int m;
private final int m;
/**
* The column dimension of the given matrix. The size of R will be m x n.
*/
private int n;
private final int n;
/**
* Calculates the QR decomposition of the given matrix.
@ -67,6 +79,9 @@ public class QRDecompositionImpl implements QRDecomposition {
n = matrix.getColumnDimension();
qr = matrix.getData();
rDiag = new double[n];
cachedQ = null;
cachedR = null;
cachedH = null;
/*
* The QR decomposition of a matrix A is calculated using Householder
@ -83,10 +98,10 @@ public class QRDecompositionImpl implements QRDecomposition {
*/
double xNormSqr = 0;
for (int row = minor; row < m; row++) {
xNormSqr += qr[row][minor]*qr[row][minor];
final double c = qr[row][minor];
xNormSqr += c * c;
}
double a = Math.sqrt(xNormSqr);
if (qr[minor][minor] > 0) a = -a;
final double a = (qr[minor][minor] > 0) ? -Math.sqrt(xNormSqr) : Math.sqrt(xNormSqr);
rDiag[minor] = a;
if (a != 0.0) {
@ -113,80 +128,240 @@ public class QRDecompositionImpl implements QRDecomposition {
* |v|^2 = -2a*(qr[minor][minor]), so
* alpha = -<x,v>/(a*qr[minor][minor])
*/
for (int col = minor+1; col < n; col++) {
for (int col = minor + 1; col < n; col++) {
double alpha = 0;
for (int row = minor; row < m; row++) {
alpha -= qr[row][col]*qr[row][minor];
final double[] qrRow = qr[row];
alpha -= qrRow[col] * qrRow[minor];
}
alpha /= a*qr[minor][minor];
alpha /= a * qr[minor][minor];
// Subtract the column vector alpha*v from x.
for (int row = minor; row < m; row++) {
qr[row][col] -= alpha*qr[row][minor];
final double[] qrRow = qr[row];
qrRow[col] -= alpha * qrRow[minor];
}
}
}
}
}
/**
* Returns the matrix R of the QR-decomposition.
*
* @return the R matrix
*/
public RealMatrix getR()
{
// R is supposed to be m x n
RealMatrixImpl ret = new RealMatrixImpl(m,n);
double[][] r = ret.getDataRef();
/** {@inheritDoc} */
public RealMatrix getR() {
// copy the diagonal from rDiag and the upper triangle of qr
for (int row = Math.min(m,n)-1; row >= 0; row--) {
r[row][row] = rDiag[row];
for (int col = row+1; col < n; col++) {
r[row][col] = qr[row][col];
if (cachedR == null) {
// R is supposed to be m x n
double[][] r = new double[m][n];
// copy the diagonal from rDiag and the upper triangle of qr
for (int row = Math.min(m,n)-1; row >= 0; row--) {
final double[] rRow = r[row];
rRow[row] = rDiag[row];
System.arraycopy(qr[row], row + 1, rRow, row + 1, n - row - 1);
}
// cache the matrix for subsequent calls
cachedR = new RealMatrixImpl(r, false);
}
return ret;
// return the cached matrix
return cachedR;
}
/**
* Returns the matrix Q of the QR-decomposition.
*
* @return the Q matrix
*/
public RealMatrix getQ()
{
// Q is supposed to be m x m
RealMatrixImpl ret = new RealMatrixImpl(m,m);
double[][] Q = ret.getDataRef();
/** {@inheritDoc} */
public RealMatrix getQ() {
/*
* Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
* applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
* succession to the result
*/
for (int minor = m-1; minor >= Math.min(m,n); minor--) {
Q[minor][minor]=1;
}
if (cachedQ == null) {
for (int minor = Math.min(m,n)-1; minor >= 0; minor--){
Q[minor][minor] = 1;
if (qr[minor][minor] != 0.0) {
for (int col = minor; col < m; col++) {
double alpha = 0;
for (int row = minor; row < m; row++) {
alpha -= Q[row][col] * qr[row][minor];
}
alpha /= rDiag[minor]*qr[minor][minor];
// Q is supposed to be m x m
double[][] Q = new double[m][m];
for (int row = minor; row < m; row++) {
Q[row][col] -= alpha*qr[row][minor];
/*
* Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
* applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
* succession to the result
*/
for (int minor = m-1; minor >= Math.min(m,n); minor--) {
Q[minor][minor]=1;
}
for (int minor = Math.min(m,n)-1; minor >= 0; minor--){
Q[minor][minor] = 1;
if (qr[minor][minor] != 0.0) {
for (int col = minor; col < m; col++) {
double alpha = 0;
for (int row = minor; row < m; row++) {
alpha -= Q[row][col] * qr[row][minor];
}
alpha /= rDiag[minor]*qr[minor][minor];
for (int row = minor; row < m; row++) {
Q[row][col] -= alpha*qr[row][minor];
}
}
}
}
// cache the matrix for subsequent calls
cachedQ = new RealMatrixImpl(Q, false);
}
return ret;
// return the cached matrix
return cachedQ;
}
/** {@inheritDoc} */
public RealMatrix getH() {
if (cachedH == null) {
double[][] hData = new double[m][n];
for (int i = 0; i < m; ++i) {
for (int j = 0; j < Math.min(i + 1, n); ++j) {
hData[i][j] = qr[i][j] / -rDiag[j];
}
}
// cache the matrix for subsequent calls
cachedH = new RealMatrixImpl(hData, false);
}
// return the cached matrix
return cachedH;
}
/** {@inheritDoc} */
public boolean isFullRank() {
for (double diag : rDiag) {
if (diag == 0) {
return false;
}
}
return true;
}
/** {@inheritDoc} */
public double[] solve(double[] b)
throws IllegalArgumentException, InvalidMatrixException {
if (b.length != m) {
throw new IllegalArgumentException("Incorrect row dimension");
}
if (!isFullRank()) {
throw new InvalidMatrixException("Matrix is rank-deficient");
}
final double[] x = new double[n];
final double[] y = b.clone();
// apply Householder transforms to solve Q.y = b
for (int minor = 0; minor < Math.min(m, n); minor++) {
double dotProduct = 0;
for (int row = minor; row < m; row++) {
dotProduct += y[row] * qr[row][minor];
}
dotProduct /= rDiag[minor] * qr[minor][minor];
for (int row = minor; row < m; row++) {
y[row] += dotProduct * qr[row][minor];
}
}
// solve triangular system R.x = y
for (int row = n - 1; row >= 0; --row) {
y[row] /= rDiag[row];
final double yRow = y[row];
x[row] = yRow;
for (int i = 0; i < row; i++) {
y[i] -= yRow * qr[i][row];
}
}
return x;
}
/** {@inheritDoc} */
public RealVector solve(RealVector b)
throws IllegalArgumentException, InvalidMatrixException {
try {
return solve((RealVectorImpl) b);
} catch (ClassCastException cce) {
return new RealVectorImpl(solve(b.getData()), false);
}
}
/** Solve the linear equation A &times; X = B.
* <p>The A matrix is implicit here. It is </p>
* @param b right-hand side of the equation A &times; X = B
* @return a vector X that minimizes the two norm of A &times; X - B
* @throws IllegalArgumentException if matrices dimensions don't match
* @throws InvalidMatrixException if decomposed matrix is singular
*/
public RealVectorImpl solve(RealVectorImpl b)
throws IllegalArgumentException, InvalidMatrixException {
return new RealVectorImpl(solve(b.getDataRef()), false);
}
/** {@inheritDoc} */
public RealMatrix solve(RealMatrix b)
throws IllegalArgumentException, InvalidMatrixException {
if (b.getRowDimension() != m) {
throw new IllegalArgumentException("Incorrect row dimension");
}
if (!isFullRank()) {
throw new InvalidMatrixException("Matrix is rank-deficient");
}
final int cols = b.getColumnDimension();
final double[][] xData = new double[n][cols];
final double[] y = new double[b.getRowDimension()];
for (int k = 0; k < cols; ++k) {
// get the right hand side vector
for (int j = 0; j < y.length; ++j) {
y[j] = b.getEntry(j, k);
}
// apply Householder transforms to solve Q.y = b
for (int minor = 0; minor < Math.min(m, n); minor++) {
double dotProduct = 0;
for (int row = minor; row < m; row++) {
dotProduct += y[row] * qr[row][minor];
}
dotProduct /= rDiag[minor] * qr[minor][minor];
for (int row = minor; row < m; row++) {
y[row] += dotProduct * qr[row][minor];
}
}
// solve triangular system R.x = y
for (int row = n - 1; row >= 0; --row) {
y[row] /= rDiag[row];
final double yRow = y[row];
xData[row][k] = yRow;
for (int i = 0; i < row; i++) {
y[i] -= yRow * qr[i][row];
}
}
}
return new RealMatrixImpl(xData, false);
}
}

View File

@ -39,6 +39,13 @@ The <action> type attribute can be add,update,fix,remove.
</properties>
<body>
<release version="2.0" date="TBD" description="TBD">
<action dev="luc" type="add" issue="MATH-220" >
Added JAMA-like interfaces for decomposition algorithms. These interfaces
decompose a matrix as a product of several other matrices with predefined
properties and shapes depending on the algorithm. These algorithms allow to
solve the equation A * X = B, either for an exact linear solution
(LU-decomposition) or an exact or least-squares solution (QR-decomposition).
</action>
<action dev="luc" type="add" issue="MATH-219" due-to="Andrew Berry">
Added removeData methods for the SimpleRegression class. This allows
to support regression calculations across a sliding window of (time-based)

View File

@ -146,6 +146,13 @@ public class QRDecompositionImplTest extends TestCase {
assertEquals("R lower triangle", R.getEntry(i, j), 0,
entryTolerance);
matrix = new RealMatrixImpl(testData3x3Singular, false);
R = new QRDecompositionImpl(matrix).getR();
for (int i = 0; i < R.getRowDimension(); i++)
for (int j = 0; j < i; j++)
assertEquals("R lower triangle", R.getEntry(i, j), 0,
entryTolerance);
matrix = new RealMatrixImpl(testData3x4, false);
R = new QRDecompositionImpl(matrix).getR();
for (int i = 0; i < R.getRowDimension(); i++)
@ -160,4 +167,186 @@ public class QRDecompositionImplTest extends TestCase {
assertEquals("R lower triangle", R.getEntry(i, j), 0,
entryTolerance);
}
/** test that H is trapezoidal */
public void testHTrapezoidal() {
RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular, false);
RealMatrix H = new QRDecompositionImpl(matrix).getH();
for (int i = 0; i < H.getRowDimension(); i++)
for (int j = i + 1; j < H.getColumnDimension(); j++)
assertEquals(H.getEntry(i, j), 0, entryTolerance);
matrix = new RealMatrixImpl(testData3x3Singular, false);
H = new QRDecompositionImpl(matrix).getH();
for (int i = 0; i < H.getRowDimension(); i++)
for (int j = i + 1; j < H.getColumnDimension(); j++)
assertEquals(H.getEntry(i, j), 0, entryTolerance);
matrix = new RealMatrixImpl(testData3x4, false);
H = new QRDecompositionImpl(matrix).getH();
for (int i = 0; i < H.getRowDimension(); i++)
for (int j = i + 1; j < H.getColumnDimension(); j++)
assertEquals(H.getEntry(i, j), 0, entryTolerance);
matrix = new RealMatrixImpl(testData4x3, false);
H = new QRDecompositionImpl(matrix).getH();
for (int i = 0; i < H.getRowDimension(); i++)
for (int j = i + 1; j < H.getColumnDimension(); j++)
assertEquals(H.getEntry(i, j), 0, entryTolerance);
}
/** test rank */
public void testRank() {
QRDecomposition qr =
new QRDecompositionImpl(new RealMatrixImpl(testData3x3NonSingular, false));
assertTrue(qr.isFullRank());
qr = new QRDecompositionImpl(new RealMatrixImpl(testData3x3Singular, false));
assertFalse(qr.isFullRank());
qr = new QRDecompositionImpl(new RealMatrixImpl(testData3x4, false));
assertFalse(qr.isFullRank());
qr = new QRDecompositionImpl(new RealMatrixImpl(testData4x3, false));
assertTrue(qr.isFullRank());
}
/** test solve dimension errors */
public void testSolveDimensionErrors() {
QRDecomposition qr =
new QRDecompositionImpl(new RealMatrixImpl(testData3x3NonSingular, false));
RealMatrix b = new RealMatrixImpl(new double[2][2]);
try {
qr.solve(b);
fail("an exception should have been thrown");
} catch (IllegalArgumentException iae) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
try {
qr.solve(b.getColumn(0));
fail("an exception should have been thrown");
} catch (IllegalArgumentException iae) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
try {
qr.solve(b.getColumnVector(0));
fail("an exception should have been thrown");
} catch (IllegalArgumentException iae) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
}
/** test solve rank errors */
public void testSolveRankErrors() {
QRDecomposition qr =
new QRDecompositionImpl(new RealMatrixImpl(testData3x3Singular, false));
RealMatrix b = new RealMatrixImpl(new double[3][2]);
try {
qr.solve(b);
fail("an exception should have been thrown");
} catch (InvalidMatrixException iae) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
try {
qr.solve(b.getColumn(0));
fail("an exception should have been thrown");
} catch (InvalidMatrixException iae) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
try {
qr.solve(b.getColumnVector(0));
fail("an exception should have been thrown");
} catch (InvalidMatrixException iae) {
// expected behavior
} catch (Exception e) {
fail("wrong exception caught");
}
}
/** test solve */
public void testSolve() {
QRDecomposition qr =
new QRDecompositionImpl(new RealMatrixImpl(testData3x3NonSingular, false));
RealMatrix b = new RealMatrixImpl(new double[][] {
{ -102, 12250 }, { 544, 24500 }, { 167, -36750 }
});
RealMatrix xRef = new RealMatrixImpl(new double[][] {
{ 1, 2515 }, { 2, 422 }, { -3, 898 }
});
// using RealMatrix
assertEquals(0, qr.solve(b).subtract(xRef).getNorm(), 1.0e-13);
// using double[]
for (int i = 0; i < b.getColumnDimension(); ++i) {
assertEquals(0,
new RealVectorImpl(qr.solve(b.getColumn(i))).subtract(xRef.getColumnVector(i)).getNorm(),
1.0e-13);
}
// using RealVectorImpl
for (int i = 0; i < b.getColumnDimension(); ++i) {
assertEquals(0,
qr.solve(b.getColumnVector(i)).subtract(xRef.getColumnVector(i)).getNorm(),
1.0e-13);
}
// using RealVector with an alternate implementation
for (int i = 0; i < b.getColumnDimension(); ++i) {
RealVectorImplTest.RealVectorTestImpl v =
new RealVectorImplTest.RealVectorTestImpl(b.getColumn(i));
assertEquals(0,
qr.solve(v).subtract(xRef.getColumnVector(i)).getNorm(),
1.0e-13);
}
}
/** test matrices values */
public void testMatricesValues() {
QRDecomposition qr =
new QRDecompositionImpl(new RealMatrixImpl(testData3x3NonSingular, false));
RealMatrix qRef = new RealMatrixImpl(new double[][] {
{ -12.0 / 14.0, 69.0 / 175.0, -58.0 / 175.0 },
{ -6.0 / 14.0, -158.0 / 175.0, 6.0 / 175.0 },
{ 4.0 / 14.0, -30.0 / 175.0, -165.0 / 175.0 }
});
RealMatrix rRef = new RealMatrixImpl(new double[][] {
{ -14.0, -21.0, 14.0 },
{ 0.0, -175.0, 70.0 },
{ 0.0, 0.0, 35.0 }
});
RealMatrix hRef = new RealMatrixImpl(new double[][] {
{ 26.0 / 14.0, 0.0, 0.0 },
{ 6.0 / 14.0, 648.0 / 325.0, 0.0 },
{ -4.0 / 14.0, 36.0 / 325.0, 2.0 }
});
// check values against known references
RealMatrix q = qr.getQ();
assertEquals(0, q.subtract(qRef).getNorm(), 1.0e-13);
RealMatrix r = qr.getR();
assertEquals(0, r.subtract(rRef).getNorm(), 1.0e-13);
RealMatrix h = qr.getH();
assertEquals(0, h.subtract(hRef).getNorm(), 1.0e-13);
// check the same cached instance is returned the second time
assertTrue(q == qr.getQ());
assertTrue(r == qr.getR());
assertTrue(h == qr.getH());
}
}

View File

@ -44,7 +44,7 @@ public class RealVectorImplTest extends TestCase {
// Testclass to test the RealVector interface
// only with enough content to support the test
public class RealVectorTestImpl implements RealVector, Serializable {
public static class RealVectorTestImpl implements RealVector, Serializable {
/** Serializable version identifier. */
private static final long serialVersionUID = 8731816072271374422L;
@ -309,7 +309,7 @@ public class RealVectorImplTest extends TestCase {
}
public double[] getData() {
throw unsupported();
return data.clone();
}
public double dotProduct(RealVector v) throws IllegalArgumentException {