Mark R. Diggory 2003-06-15 17:01:39 +00:00
parent 09c8b57924
commit 765db662a5
3 changed files with 868 additions and 218 deletions

View File

@ -57,180 +57,227 @@ package org.apache.commons.math;
/**
* Interface defining a real-valued matrix with basic algebraic operations
* @author Phil Steitz
* @version $Revision: 1.3 $ $Date: 2003/06/11 14:50:29 $
* @version $Revision: 1.4 $ $Date: 2003/06/15 17:01:39 $
*/
public interface RealMatrix {
public RealMatrix copy();
/**
* Returns a (deep) copy of this.
*
* @return matrix copy
*/
RealMatrix copy();
/**
* Compute the sum of *this and m
* Compute the sum of this and m.
*
* @param m matrix to be added
* @return this + m
* @exception IllegalArgumentException if m is not the same size as *this
* @exception IllegalArgumentException if m is not the same size as this
*/
public RealMatrix add(RealMatrix m);
RealMatrix add(RealMatrix m) throws IllegalArgumentException;
/**
* Compute *this minus m
* Compute this minus m.
*
* @param m matrix to be subtracted
* @return this + m
* @exception IllegalArgumentException if m is not the same size as *this
* @exception IllegalArgumentException if m is not the same size as this
*/
public RealMatrix subtract(RealMatrix m);
RealMatrix subtract(RealMatrix m) throws IllegalArgumentException;
/**
* Returns the rank of the matrix
* @return the rank of this matrix
* Returns the rank of the matrix.
*
* @return the rank of this matrix
*/
public int getRank();
int getRank();
/**
* Returns the result of adding d to each entry of *this
* Returns the result of adding d to each entry of this.
*
* @param d value to be added to each entry
* @return d + this
*/
public RealMatrix scalarAdd(double d);
RealMatrix scalarAdd(double d);
/**
* Returns the result multiplying each entry of *this by d
* Returns the result multiplying each entry of this by d.
*
* @param d value to multiply all entries by
* @return d*this
* @return d * this
*/
public RealMatrix scalarMultiply(double d);
RealMatrix scalarMultiply(double d);
/**
* Returns the result postmultiplyin *this by m
* Returns the result postmultiplying this by m.
*
* @param m matrix to postmultiply by
* @return this*m
* @return this * m
* @throws IllegalArgumentException
* if columnDimension(this) != rowDimension(m)
*/
public RealMatrix multiply(RealMatrix m);
RealMatrix multiply(RealMatrix m) throws IllegalArgumentException;
/**
* Returns matrix entries as a two-dimensional array
* Returns matrix entries as a two-dimensional array.
*
* @return 2-dimensional array of entries
*/
public double[][] getData();
double[][] getData();
/**
* Sets/overwrites the underlying data for the matrix
* Overwrites the underlying data for the matrix with
* a fresh copy of <code>data</code>.
*
* @param data 2-dimensional array of entries
*/
public void setData(double[][] data);
void setData(double[][] data);
/**
* Returns the norm of the matrix
* Returns the <a href="http://mathworld.wolfram.com/
* MaximumAbsoluteRowSumNorm.html">maximum absolute row sum norm</a>
* of the matrix.
*
* @return norm
*/
public double getNorm();
double getNorm();
/**
* Returns entries in row as an array
* @param row the row to be fetched
* @return array of entries in the row
* @throws IllegalArgumentException if row > rowDimension
* Returns the entries in row number <code>row</code> as an array.
*
* @param row the row to be fetched
* @return array of entries in the row
* @throws IllegalArgumentException if row > rowDimension
*/
public double[] getRow(int row);
double[] getRow(int row) throws IllegalArgumentException;
/**
* Returns entries in column as an array
* Returns the entries in column number <code>col</code> as an array.
*
* @param col column to fetch
* @return array of entries in the column
* @throws IllegalArgumentException if column > columnDimension
* @return array of entries in the column
* @throws IllegalArgumentException if column > columnDimension
*/
public double[] getColumn(int col);
double[] getColumn(int col) throws IllegalArgumentException;
/**
* Returns the entry in the specified row and column
* Returns the entry in the specified row and column.
*
* @param row row location of entry to be fetched
* @param column column location of entry to be fetched
* @return matrix entry in row,column
* @throws IllegalArgumentException if entry does not exist
*/
public double getEntry(int row, int column);
double getEntry(int row, int column) throws IllegalArgumentException;
/**
* Sets the entry in the specified row and column to the specified value
* Sets the entry in the specified row and column to the specified value.
*
* @param row row location of entry to be set
* @param column column location of entry to be set
* @param value value to set
* @throws IllegalArgumentException if entry does not exist
*/
public void setEntry(int row, int column, double value);
void setEntry(int row, int column, double value)
throws IllegalArgumentException;
/**
* Returns the transpose of this matrix
* Returns the transpose of this matrix.
*
* @return transpose matrix
*/
public RealMatrix transpose();
RealMatrix transpose();
/**
* Returns the inverse of this matrix
* Returns the inverse of this matrix.
*
* @return inverse matrix
* @throws IllegalArgumentException if *this is not invertible
*/
public RealMatrix inverse();
RealMatrix inverse() throws IllegalArgumentException;
/**
* Returns the determinant of this matrix
* Returns the determinant of this matrix.
*
* @return determinant
*/
public double getDeterminant();
double getDeterminant();
/**
* Is this a square matrix?
* @return true if the matrix is square (rowDimension = columnDimension)
*/
public boolean isSquare();
boolean isSquare();
/**
* Is this a singular matrix?
* @return true if the matrix is singular
*/
public boolean isSingular();
boolean isSingular();
/**
* Returns the number of rows in the matrix
* Returns the number of rows in the matrix.
*
* @return rowDimension
*/
public int getRowDimension();
int getRowDimension();
/**
* Returns the number of columns in the matrix
* Returns the number of columns in the matrix.
*
* @return columnDimension
*/
public int getColumnDimension();
int getColumnDimension();
/**
* Returns the trace of the matrix
* Returns the <a href="http://mathworld.wolfram.com/MatrixTrace.html">
* trace</a> of the matrix (the sum of the elements on the main diagonal).
*
* @return trace
*/
public double getTrace();
double getTrace();
/**
* Returns the result of multiplying this by vector v
* Returns the result of multiplying this by the vector <code>v</code>.
*
* @param v the vector to operate on
* @return this*v
* @throws IllegalArgumentException if columnDimension != v.size()
*/
public double[] operate(double[] v);
double[] operate(double[] v) throws IllegalArgumentException;
/**
* Returns the result of premultiplying this by vector v
* Returns the result of premultiplying this by the vector <code>v</code>.
*
* @param v the row vector to premultiply by
* @return v*this
* @throws IllegalArgumentException if rowDimension != v.size()
*/
public RealMatrix preMultiply(double[] v);
RealMatrix preMultiply(double[] v) throws IllegalArgumentException;
/**
* Returns the solution vector for a linear system with coefficient
* matrix = *this and constant vector = b
* matrix = this and constant vector = <code>b</code>.
*
* @param b constant vector
* @return vector of solution values to AX = b, where A is *this
* @throws IllegalArgumentException if rowDimension != b.length or matrix
* is singular
*/
public double[] solve(double[] b);
double[] solve(double[] b) throws IllegalArgumentException;
/**
* Returns a matrix of (column) solution vectors for linear systems with
* coefficient matrix = this and constant vectors = columns of
* <code>b</code>.
*
* @param b matrix of constant vectors forming RHS of linear systems to
* to solve
* @return matrix of solution vectors
* @throws IllegalArgumentException if rowDimension != row dimension of b
* or this is singular
*/
RealMatrix solve(RealMatrix b) throws IllegalArgumentException;
}

View File

@ -56,50 +56,118 @@ package org.apache.commons.math;
import java.io.Serializable;
/**
* Implementation for RealMatrix using double[][] array
* @author Phil Stetiz
* @version $Revision: 1.3 $ $Date: 2003/06/11 14:50:29 $
* Implementation for RealMatrix using a double[][] array to store entries
* and <a href="http://www.math.gatech.edu/~bourbaki/
* math2601/Web-notes/2num.pdf">LU decompostion</a> to support linear system
* solution and inverse.
* <p>
* The <a href="http://www.math.gatech.edu/~bourbaki/math2601/Web-notes
* /2num.pdf">LU decompostion</a> is performed as needed, to support the
* following operations: <ul>
* <li>solve</li>
* <li>isSingular</li>
* <li>getDeterminant</li>
* <li>inverse</li> </ul>
* <p>
* <strong>Usage note</strong>:<br>
* The LU decomposition is stored and reused on subsequent calls. If matrix
* data are modified using any of the public setXxx methods, the saved
* decomposition is discarded. If data are modified via references to the
* underlying array obtained using <code>getDataRef()</code>, then the stored
* LU decomposition will not be discarded. In this case, you need to
* explicitly invoke <code>LUDecompose()</code> to recompute the decomposition
* before using any of the methods above.
*
* @author Phil Steitz
* @version $Revision: 1.4 $ $Date: 2003/06/15 17:01:39 $
*/
public class RealMatrixImpl implements RealMatrix, Serializable {
/** Entries of the matrix */
private double data[][];
private double data[][] = null;
/** Entries of LU decomposition.
* All updates to data (other than luDecompostion) *must* set this to null
*/
private double lu[][] = null;
/** Pivot array associated with LU decompostion */
private int[] pivot = null;
/** Parity of the permutation associated with the LU decomposition */
private int parity = 1;
/** Bound to determine effective singularity in LU decomposition */
private static double TOO_SMALL = 10E-12;
/**
* Creates a matrix with no data
*/
public RealMatrixImpl() {
}
/**
* Create a new RealMatrix with the supplied row and column dimensions
* @param rowDimension the number of rows in the new matrix
* @param columnDimension the number of columns in the new matrix
*/
public RealMatrixImpl(int rowDimension,
int columnDimension) {
data = new double[rowDimension][columnDimension];
}
public RealMatrixImpl(double[][] data) {
this.data = data;
}
/**
* Create a new RealMatrix which is a copy of *this
* Create a new RealMatrix with the supplied row and column dimensions.
*
* @param rowDimension the number of rows in the new matrix
* @param columnDimension the number of columns in the new matrix
*/
public RealMatrixImpl(int rowDimension,
int columnDimension) {
data = new double[rowDimension][columnDimension];
lu = null;
}
/**
* Create a new RealMatrix using the <code>data</code> as the underlying
* data array.
* <p>
* The input array is copied, not referenced.
*
* @param d data for new matrix
*/
public RealMatrixImpl(double[][] d) {
this.copyIn(d);
lu = null;
}
/**
* Create a new (column) RealMatrix using <code>v</code> as the
* data for the unique column of the <code>v.length x 1</code> matrix
* created.
* <p>
* The input array is copied, not referenced.
*
* @param v column vector holding data for new matrix
*/
public RealMatrixImpl(double[] v) {
int nRows = v.length;
data = new double[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = v[row];
}
}
/**
* Create a new RealMatrix which is a copy of this.
*
* @return the cloned matrix
*/
public RealMatrix copy() {
throw new UnsupportedOperationException("not implemented yet");
return new RealMatrixImpl(this.copyOut());
}
/**
* Compute the sum of *this and m
* Compute the sum of this and <code>m</code>.
*
* @param m matrix to be added
* @return this + m
* @exception IllegalArgumentException if m is not the same size as *this
* @exception IllegalArgumentException if m is not the same size as this
*/
public RealMatrix add(RealMatrix m) {
public RealMatrix add(RealMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getColumnDimension() ||
this.getRowDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
this.getRowDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
@ -114,15 +182,16 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
}
/**
* Compute *this minus m
* Compute this minus <code>m</code>.
*
* @param m matrix to be subtracted
* @return this + m
* @exception IllegalArgumentException if m is not the same size as *this
*/
public RealMatrix subtract(RealMatrix m) {
public RealMatrix subtract(RealMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getColumnDimension() ||
this.getRowDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
this.getRowDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
@ -137,16 +206,19 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
}
/**
* Returns the rank of the matrix
* @return the rank of this matrix
* Returns the rank of the matrix.
*
* @return the rank of this matrix
*/
public int getRank() {
// FIXME: need to add singular value decomposition or drop this
throw new UnsupportedOperationException("not implemented yet");
}
/**
* Returns the result of adding d to each entry of *this
/**
* Returns the result of adding d to each entry of this.
*
* @param d value to be added to each entry
* @return d + this
*/
@ -161,11 +233,11 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
}
return new RealMatrixImpl(outData);
}
/**
* Returns the result multiplying each entry of *this by d
* @param d value to multiply all entries by
* @return d*this
* Returns the result multiplying each entry of this by <code>d</code>
* @param d value to multiply all entries by
* @return d * this
*/
public RealMatrix scalarMultiply(double d) {
int rowCount = this.getRowDimension();
@ -173,152 +245,216 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col]*d;
outData[row][col] = data[row][col] * d;
}
}
return new RealMatrixImpl(outData);
}
/**
* Returns the result postmultiplying *this by m
* Returns the result postmultiplying this by <code>m</code>.
* @param m matrix to postmultiply by
* @return this*m
* @throws IllegalArgumentException
* @throws IllegalArgumentException
* if columnDimension(this) != rowDimension(m)
*/
public RealMatrix multiply(RealMatrix m) {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException
public RealMatrix multiply(RealMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException
("Matrices are not multiplication compatible.");
}
double[][] mData = m.getData();
double[][] outData =
new double[this.getRowDimension()][m.getColumnDimension()];
double sum = 0;
for (int row = 0; row < this.getRowDimension(); row++) {
for (int col = 0; col < m.getColumnDimension(); col++) {
sum = 0;
for (int i = 0; i < this.getColumnDimension(); i++) {
sum += data[row][i] * mData[i][col];
}
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
double[][] mData = m.getData();
double[][] outData =
new double[nRows][nCols];
double sum = 0;
for (int row = 0; row < nRows; row++) {
for (int col = 0; col < nCols; col++) {
sum = 0;
for (int i = 0; i < nCols; i++) {
sum += data[row][i] * mData[i][col];
}
outData[row][col] = sum;
}
outData[row][col] = sum;
}
}
return new RealMatrixImpl(outData);
}
return new RealMatrixImpl(outData);
}
/**
* Returns matrix entries as a two-dimensional array
* Returns matrix entries as a two-dimensional array.
* <p>
* Makes a fresh copy of the underlying data.
*
* @return 2-dimensional array of entries
*/
public double[][] getData() {
return copyOut();
}
/**
* Overwrites the underlying data for the matrix
* with a fresh copy of <code>inData</code>.
*
* @param inData 2-dimensional array of entries
*/
public void setData(double[][] inData) {
copyIn(inData);
lu = null;
}
/**
* Returns a reference to the underlying data array.
* <p>
* Does not make a fresh copy of the underlying data.
*
* @return 2-dimensional array of entries
*/
public double[][] getDataRef() {
return data;
}
/**
* Sets/overwrites the underlying data for the matrix
* @param data 2-dimensional array of entries
* Overwrites the underlying data for the matrix
* with a reference to <code>inData</code>.
* <p>
* Does not make a fresh copy of <code>data</code>.
*
* @param inData 2-dimensional array of entries
*/
public void setData(double[][] data) {
this.data = data;
public void setDataRef(double[][] inData) {
this.data = inData;
lu = null;
}
/**
* Returns the 1-norm of the matrix (max column sum)
*
* @return norm
*/
public double getNorm() {
double maxColSum = 0;
for (int col = 0; col < this.getColumnDimension(); col++) {
double sum = 0;
for (int row = 0; row < this.getRowDimension(); row++) {
sum += Math.abs(data[row][col]);
}
maxColSum = Math.max(maxColSum,sum);
}
return maxColSum;
}
/**
* Returns entries in row as an array
* @param row the row to be fetched
* @return array of entries in the row
* @throws IllegalArgumentException if row > rowDimension
*/
public double[] getRow(int row) {
return data[row];
}
/**
* Returns entries in column as an array
* @param col column to fetch
* @return array of entries in the column
* @throws IllegalArgumentException if column > columnDimension
*/
public double[] getColumn(int col) {
throw new UnsupportedOperationException("not implemented yet");
}
/**
* Returns the entry in the specified row and column
* @param row row location of entry to be fetched
* @param column column location of entry to be fetched
* @return matrix entry in row,column
* @throws IllegalArgumentException if entry does not exist
*/
public double getEntry(int row, int column) {
if (row < 1 || column < 1 || row > this.getRowDimension()
|| column > this.getColumnDimension()) {
throw new IllegalArgumentException
("matrix entry does not exist");
double maxColSum = 0;
for (int col = 0; col < this.getColumnDimension(); col++) {
double sum = 0;
for (int row = 0; row < this.getRowDimension(); row++) {
sum += Math.abs(data[row][col]);
}
maxColSum = Math.max(maxColSum, sum);
}
return data[row-1][column-1];
return maxColSum;
}
/**
* Sets the entry in the specified row and column to the specified value
* @param row row location of entry to be set
* @param column column location of entry to be set
* @param value value to set
*
* @param row the row to be fetched
* @return array of entries in the row
* @throws IllegalArgumentException if row > rowDimension or row < 1
*/
public double[] getRow(int row) throws IllegalArgumentException {
if (row > this.getRowDimension() || row < 1) {
throw new IllegalArgumentException("illegal row argument");
}
int ncols = this.getColumnDimension();
double[] out = new double[ncols];
System.arraycopy(data[row - 1], 0, out, 0, ncols);
return out;
}
/**
* @param col column to fetch
* @return array of entries in the column
* @throws IllegalArgumentException if column > columnDimension or
* column < 1
*/
public double[] getColumn(int col) throws IllegalArgumentException {
if (col > this.getColumnDimension() || col < 1) {
throw new IllegalArgumentException("illegal column argument");
}
int nRows = this.getRowDimension();
double[] out = new double[nRows];
for (int row = 0; row < nRows; row++) {
out[row] = data[row][col - 1];
}
return out;
}
/**
* @param row row location of entry to be fetched
* @param column column location of entry to be fetched
* @return matrix entry in row,column
* @throws IllegalArgumentException if entry does not exist
*/
public void setEntry(int row, int column, double value) {
public double getEntry(int row, int column)
throws IllegalArgumentException {
if (row < 1 || column < 1 || row > this.getRowDimension()
|| column > this.getColumnDimension()) {
throw new IllegalArgumentException
("matrix entry does not exist");
|| column > this.getColumnDimension()) {
throw new IllegalArgumentException
("matrix entry does not exist");
}
data[row-1][column-1] = value;
return data[row - 1][column - 1];
}
/**
* Returns the transpose of this matrix
* @param row row location of entry to be set
* @param column column location of entry to be set
* @param value value to set
* @throws IllegalArgumentException if entry does not exist
*/
public void setEntry(int row, int column, double value)
throws IllegalArgumentException {
if (row < 1 || column < 1 || row > this.getRowDimension()
|| column > this.getColumnDimension()) {
throw new IllegalArgumentException
("matrix entry does not exist");
}
data[row - 1][column - 1] = value;
lu = null;
}
/**
*
* @return transpose matrix
*/
public RealMatrix transpose() {
throw new UnsupportedOperationException("not implemented yet");
}
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
RealMatrixImpl out = new RealMatrixImpl(nCols, nRows);
double[][] outData = out.getDataRef();
for (int row = 0; row < nRows; row++) {
for (int col = 0; col < nCols; col++) {
outData[col][row] = data[row][col];
}
}
return out;
}
/**
* Returns the inverse of this matrix
* @return inverse matrix
* @throws IllegalArgumentException if *this is not invertible
* @throws IllegalArgumentException if this is not invertible
*/
public RealMatrix inverse() {
throw new UnsupportedOperationException("not implemented yet");
public RealMatrix inverse() throws IllegalArgumentException {
return solve(getIdentity(this.getRowDimension()));
}
/**
* Returns the determinant of this matrix
* @return determinant
* @throws IllegalArgumentException if matrix is not square
*/
public double getDeterminant() {
throw new UnsupportedOperationException("not implemented yet");
public double getDeterminant() throws IllegalArgumentException {
if (!isSquare()) {
throw new IllegalArgumentException("matrix is not square");
}
if (isSingular()) { // note: this has side effect of attempting LU
return 0d; // decomp if lu == null
} else {
double det = (double) parity;
for (int i = 0; i < this.getRowDimension(); i++) {
det *= lu[i][i];
}
return det;
}
}
/**
* Is this a square matrix?
* @return true if the matrix is square (rowDimension = columnDimension)
*/
public boolean isSquare() {
@ -326,23 +462,29 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
}
/**
* Is this a singular matrix?
* @return true if the matrix is singular
*/
public boolean isSingular() {
throw new UnsupportedOperationException("not implemented yet");
if (lu == null) {
try {
LUDecompose();
return false;
} catch (IllegalArgumentException ex) {
return true;
}
} else { // LU decomp must have been successfully performed
return false; // so the matrix is not singular
}
}
/**
* Returns the number of rows in the matrix
* @return rowDimension
*/
public int getRowDimension() {
return data.length;
return data.length;
}
/**
* Returns the number of columns in the matrix
* @return columnDimension
*/
public int getColumnDimension() {
@ -350,41 +492,276 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
}
/**
* Returns the trace of the matrix
* @return trace
* @throws IllegalArgumentException if the matrix is not square
*/
public double getTrace() {
throw new UnsupportedOperationException("not implemented yet");
public double getTrace() throws IllegalArgumentException {
if (!isSquare()) {
throw new IllegalArgumentException("matrix is not square");
}
double trace = data[0][0];
for (int i = 1; i < this.getRowDimension(); i++) {
trace += data[i][i];
}
return trace;
}
/**
* Returns the result of multiplying this by the vector b
* @return this*v
* @throws IllegalArgumentException if columnDimension != v.size()
* @param v vector to operate on
* @throws IllegalArgumentException if columnDimension != v.length
* @return resulting vector
*/
public double[] operate(double[] v) {
throw new UnsupportedOperationException("not implemented yet");
public double[] operate(double[] v) throws IllegalArgumentException {
if (v.length != this.getColumnDimension()) {
throw new IllegalArgumentException("vector has wrong length");
}
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
double[] out = new double[v.length];
for (int row = 0; row < nRows; row++) {
double sum = 0;
for (int i = 0; i < nCols; i++) {
sum += data[row][i] * v[i];
}
out[row] = sum;
}
return out;
}
/**
* Returns the result of premultiplying this by the vector v
* @return v*this
* @throws IllegalArgumentException if rowDimension != v.size()
* @param v vector to premultiply by
* @throws IllegalArgumentException if rowDimension != v.length
* @return resulting matrix
*/
public RealMatrix preMultiply(double[] v) {
throw new UnsupportedOperationException("not implemented yet");
public RealMatrix preMultiply(double[] v) throws IllegalArgumentException {
int nCols = this.getColumnDimension();
if (v.length != nCols) {
throw new IllegalArgumentException("vector has wrong length");
}
// being a bit lazy here -- probably should implement directly, like
// operate
RealMatrix pm = new RealMatrixImpl(v).transpose();
return pm.multiply(this);
}
/**
* Returns the solution vector for a linear system with coefficient
* matrix = *this and constant vector = b
* @param b constant vector
* @return vector of solution values to AX = b, where A is *this
* @throws IllegalArgumentException if rowDimension != b.length or matrix
* @return vector of solution values to AX = b, where A is this
* @throws IllegalArgumentException if rowDimension != b.length or matrix
* is singular
*/
public double[] solve(double[] b) {
throw new UnsupportedOperationException("not implemented yet");
}
public double[] solve(double[] b) throws IllegalArgumentException {
int nRows = this.getRowDimension();
if (b.length != nRows) {
throw new IllegalArgumentException
("constant vector has wrong length");
}
RealMatrix bMatrix = new RealMatrixImpl(b);
double[][] solution = ((RealMatrixImpl) (solve(bMatrix))).getDataRef();
double[] out = new double[nRows];
for (int row = 0; row < nRows; row++) {
out[row] = solution[row][0];
}
return out;
}
/**
* Uses LU decomposition, performing the composition if the matrix has
* not been decomposed, or if there have been changes to the matrix since
* the last decomposition.
*
* @param b the constant vector
* @return solution matrix
* @throws IllegalArgumentException if this is singular or dimensions
* do not match.
*/
public RealMatrix solve(RealMatrix b) throws IllegalArgumentException {
if (b.getRowDimension() != this.getRowDimension()) {
throw new IllegalArgumentException("Incorrect row dimension");
}
if (this.isSingular()) { // side effect: compute LU decomp
throw new IllegalArgumentException("Matrix is singular.");
}
int nCol = this.getColumnDimension();
int nRow = this.getRowDimension();
int nColB = b.getColumnDimension();
int nRowB = b.getRowDimension();
// Apply permutations to b
double[][] bv = b.getData();
double[][] bp = new double[nRowB][nColB];
for (int row = 0; row < nRowB; row++) {
for (int col = 0; col < nColB; col++) {
bp[row][col] = bv[pivot[row]][col];
}
}
bv = null;
// Solve LY = b
for (int col = 0; col < nCol; col++) {
for (int i = col + 1; i < nCol; i++) {
for (int j = 0; j < nColB; j++) {
bp[i][j] -= bp[col][j] * lu[i][col];
}
}
}
// Solve UX = Y
for (int col = nCol - 1; col >= 0; col--) {
for (int j = 0; j < nColB; j++) {
bp[col][j] /= lu[col][col];
}
for (int i = 0; i < col; i++) {
for (int j = 0; j < nColB; j++) {
bp[i][j] -= bp[col][j] * lu[i][col];
}
}
}
RealMatrixImpl outMat = new RealMatrixImpl(bp);
return outMat;
}
/**
* Computes a new <a href="http://www.math.gatech.edu/~bourbaki/
* math2601/Web-notes/2num.pdf">LU decompostion</a> for this matrix,
* storing the result for use by other methods.
* <p>
* <strong>Implementation Note</strong>:<br>
* Uses <a href="http://www.damtp.cam.ac.uk/user/fdl/
* people/sd/lectures/nummeth98/linear.htm">Crout's algortithm</a>,
* with partial pivoting.
* <p>
* <strong>Usage Note</strong>:<br>
* This method should rarely be invoked directly. Its only use is
* to force recomputation of the LU decomposition when changes have been
* made to the underlying data using direct array references. Changes
* made using setXxx methods will trigger recomputation when needed
* automatically.
*
* @throws IllegalArgumentException if the matrix is singular
*/
public void LUDecompose() throws IllegalArgumentException {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
lu = this.getData();
// Initialize pivot array and parity
pivot = new int[nRows];
for (int row = 0; row < nRows; row++) {
pivot[row] = row;
}
parity = 1;
// Loop over columns
for (int col = 0; col < nCols; col++) {
double sum = 0;
// upper
for (int row = 0; row < col; row++) {
sum = lu[row][col];
for (int i = 0; i < row; i++) {
sum -= lu[row][i] * lu[i][col];
}
lu[row][col] = sum;
}
// lower
int max = col; // pivot row
double largest = 0d;
for (int row = col; row < nRows; row++) {
sum = lu[row][col];
for (int i = 0; i < col; i++) {
sum -= lu[row][i] * lu[i][col];
}
lu[row][col] = sum;
// maintain best pivot choice
if (Math.abs(sum) > largest) {
largest = Math.abs(sum);
max = row;
}
}
// Singularity check
if (Math.abs(lu[max][col]) < TOO_SMALL) {
lu = null;
throw new IllegalArgumentException("matrix is singular");
}
// Pivot if necessary
if (max != col) {
double tmp = 0;
for (int i = 0; i < nCols; i++) {
tmp = lu[max][i];
lu[max][i] = lu[col][i];
lu[col][i] = tmp;
}
int temp = pivot[max];
pivot[max] = pivot[col];
pivot[col] = temp;
parity = -parity;
}
//Divide the lower elements by the "winning" diagonal elt.
for (int row = col + 1; row < nRows; row++) {
lu[row][col] /= lu[col][col];
}
}
}
//------------------------ Protected methods
/**
* Returns <code>dimension x dimension</code> identity matrix.
*
* @param dimension dimension of identity matrix to generate
* @return identity matrix
*/
protected RealMatrix getIdentity(int dimension) {
RealMatrixImpl out = new RealMatrixImpl(dimension, dimension);
double[][] d = out.getDataRef();
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
d[row][col] = row == col ? 1d : 0d;
}
}
return out;
}
//------------------------ Private methods
/**
* Returns a fresh copy of the underlying data array.
*
* @return a copy of the underlying data array.
*/
private double[][] copyOut() {
int nRows = this.getRowDimension();
double[][] out =
new double[nRows][this.getColumnDimension()];
// can't copy 2-d array in one shot, otherwise get row references
for (int i = 0; i < nRows; i++) {
System.arraycopy(data[i], 0, out[i], 0, data[i].length);
}
return out;
}
/**
* Replaces data with a fresh copy of the input array.
*
* @param in data to copy in
*/
private void copyIn(double[][] in) {
int nRows = in.length;
int nCols = in[0].length;
data = new double[nRows][nCols];
System.arraycopy(in, 0, data, 0, in.length);
for (int i = 0; i < nRows ; i++) {
System.arraycopy(in[i], 0, data[i], 0, nCols);
}
lu = null;
}
}

View File

@ -61,21 +61,34 @@ import junit.framework.TestSuite;
* Test cases for the {@link RealMatrixImpl} class.
*
* @author Phil Steitz
* @version $Revision: 1.1 $ $Date: 2003/05/12 19:02:53 $
* @version $Revision: 1.2 $ $Date: 2003/06/15 17:01:39 $
*/
public final class RealMatrixImplTest extends TestCase {
private double[][] testData = { {1d,2d,3d}, {2d,5d,3d}, {1d,0d,8d} };
private double[][] testDataPlus2 = { {3d,4d,5d}, {4d,7d,5d}, {3d,2d,10d} };
private double[][] testDataMinus = { {-1d,-2d,-3d}, {-2d,-5d,-3d},
{-1d,0d,-8d} };
private double[] testDataRow1 = {1d,2d,3d};
private double[] testDataCol3 = {3d,3d,8d};
private double[][] testDataInv =
{ {-40d,16d,9d}, {13d,-5d,-3d}, {5d,-2d,-1d} };
private double[][] preMultTest = {{8,12,33}};
private double[][] testData2 ={ {1d,2d,3d}, {2d,5d,3d}};
private double[][] testData2T = { {1d,2d}, {2d,5d}, {3d,3d}};
private double[][] testDataPlusInv =
{ {-39d,18d,12d}, {15d,0d,0d}, {6d,-2d,7d} };
private double[][] id = { {1d,0d,0d}, {0d,1d,0d}, {0d,0d,1d} };
private double[][] luData = { {2d,3d,3d}, {0d,5d,7d}, {6d,9d,8d} };
private double[][] singular = { {2d,3d}, {2d,3d} };
private double[][] bigSingular = {{1d,2d,3d,4d}, {2d,5d,3d,4d},
{7d,3d,256d,1930d}, {3d,7d,6d,8d}}; // 4th row = 1st + 2nd
private double[][] detData = { {1d,2d,3d}, {4d,5d,6d}, {7d,8d,10d} };
private double[] testVector = {1,2,3};
private double entryTolerance = Math.pow(2,-64);
private double normTolerance = Math.pow(2,-64);
private double[] testVector2 = {1,2,3,4};
private double entryTolerance = 10E-16;
private double normTolerance = 10E-14;
public RealMatrixImplTest(String name) {
super(name);
@ -101,7 +114,24 @@ public final class RealMatrixImplTest extends TestCase {
assertEquals("testData2 row dimension",m2.getRowDimension(),2);
assertEquals("testData2 column dimension",m2.getColumnDimension(),3);
assertTrue("testData2 is not square",!m2.isSquare());
}
}
/** test copy functions */
public void testCopyFunctions() {
RealMatrixImpl m = new RealMatrixImpl(testData);
RealMatrixImpl m2 = new RealMatrixImpl(testData2);
m2.setData(m.getData());
assertClose("getData",m2,m,entryTolerance);
// no dangling reference...
m2.setEntry(1,1,2000d);
RealMatrixImpl m3 = new RealMatrixImpl(testData);
assertClose("no getData side effect",m,m3,entryTolerance);
m3 = (RealMatrixImpl) m.copy();
double[][] stompMe = {{1d,2d,3d}};
m3.setDataRef(stompMe);
assertClose("no copy side effect",m,new RealMatrixImpl(testData),
entryTolerance);
}
/** test add */
public void testAdd() {
@ -143,7 +173,13 @@ public final class RealMatrixImplTest extends TestCase {
RealMatrixImpl m = new RealMatrixImpl(testData);
RealMatrixImpl m2 = new RealMatrixImpl(testDataInv);
assertClose("m-n = m + -n",m.subtract(m2),
m2.scalarMultiply(-1d).add(m),entryTolerance);
m2.scalarMultiply(-1d).add(m),entryTolerance);
try {
RealMatrix a = m.subtract(new RealMatrixImpl(testData2));
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test multiply */
@ -161,13 +197,203 @@ public final class RealMatrixImplTest extends TestCase {
assertClose("identity multiply",identity.multiply(mInv),
mInv,entryTolerance);
assertClose("identity multiply",m2.multiply(identity),
m2,entryTolerance);
m2,entryTolerance);
try {
RealMatrix a = m.multiply(new RealMatrixImpl(bigSingular));
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test isSingular */
public void testIsSingular() {
RealMatrixImpl m = new RealMatrixImpl(singular);
assertTrue("singular",m.isSingular());
m = new RealMatrixImpl(bigSingular);
assertTrue("big singular",m.isSingular());
m = new RealMatrixImpl(id);
assertTrue("identity nonsingular",!m.isSingular());
m = new RealMatrixImpl(testData);
assertTrue("testData nonsingular",!m.isSingular());
}
/** test inverse */
public void testInverse() {
RealMatrixImpl m = new RealMatrixImpl(testData);
RealMatrix mInv = new RealMatrixImpl(testDataInv);
assertClose("inverse",mInv,m.inverse(),normTolerance);
assertClose("inverse^2",m,m.inverse().inverse(),10E-12);
}
/** test solve */
public void testSolve() {
RealMatrixImpl m = new RealMatrixImpl(testData);
RealMatrix mInv = new RealMatrixImpl(testDataInv);
// being a bit slothful here -- actually testing that X = A^-1 * B
assertClose("inverse-operate",mInv.operate(testVector),
m.solve(testVector),normTolerance);
try {
double[] x = m.solve(testVector2);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
RealMatrix bs = new RealMatrixImpl(bigSingular);
try {
RealMatrix a = bs.solve(bs);
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
try {
RealMatrix a = m.solve(bs);
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test determinant */
public void testDeterminant() {
RealMatrix m = new RealMatrixImpl(bigSingular);
assertEquals("singular determinant",0,m.getDeterminant(),0);
m = new RealMatrixImpl(detData);
assertEquals("nonsingular test",-3d,m.getDeterminant(),normTolerance);
try {
double a = new RealMatrixImpl(testData2).getDeterminant();
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test trace */
public void testTrace() {
RealMatrix m = new RealMatrixImpl(id);
assertEquals("identity trace",3d,m.getTrace(),entryTolerance);
m = new RealMatrixImpl(testData2);
try {
double x = m.getTrace();
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test sclarAdd */
public void testScalarAdd() {
RealMatrix m = new RealMatrixImpl(testData);
assertClose("scalar add",new RealMatrixImpl(testDataPlus2),
m.scalarAdd(2d),entryTolerance);
}
/** test operate */
public void testOperate() {
RealMatrix m = new RealMatrixImpl(id);
double[] x = m.operate(testVector);
assertClose("identity operate",testVector,x,entryTolerance);
m = new RealMatrixImpl(bigSingular);
try {
x = m.operate(testVector);
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test transpose */
public void testTranspose() {
RealMatrix m = new RealMatrixImpl(testData);
assertClose("inverse-transpose",m.inverse().transpose(),
m.transpose().inverse(),normTolerance);
m = new RealMatrixImpl(testData2);
RealMatrix mt = new RealMatrixImpl(testData2T);
assertClose("transpose",mt,m.transpose(),normTolerance);
}
/** test preMultiply */
public void testPremultiply() {
RealMatrix m = new RealMatrixImpl(testData);
RealMatrix mp = new RealMatrixImpl(preMultTest);
assertClose("premultiply",m.preMultiply(testVector),mp,normTolerance);
m = new RealMatrixImpl(bigSingular);
try {
RealMatrix x = m.preMultiply(testVector);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
public void testGetVectors() {
RealMatrix m = new RealMatrixImpl(testData);
assertClose("get row",m.getRow(1),testDataRow1,entryTolerance);
assertClose("get col",m.getColumn(3),testDataCol3,entryTolerance);
try {
double[] x = m.getRow(10);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
try {
double[] x = m.getColumn(-1);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
public void testEntryMutators() {
RealMatrix m = new RealMatrixImpl(testData);
assertEquals("get entry",m.getEntry(1,2),2d,entryTolerance);
m.setEntry(1,2,100d);
assertEquals("get entry",m.getEntry(1,2),100d,entryTolerance);
try {
double x = m.getEntry(0,2);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
try {
m.setEntry(1,4,200d);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
//--------------- -----------------Private methods
/** verifies that two matrices are close (1-norm) */
private void assertClose(String msg, RealMatrix m, RealMatrix n,
double tolerance) {
assertTrue(msg,m.subtract(n).getNorm() < tolerance);
}
/** verifies that two vectors are close (sup norm) */
private void assertClose(String msg, double[] m, double[] n,
double tolerance) {
if (m.length != n.length) {
fail("vectors not same length");
}
for (int i = 0; i < m.length; i++) {
assertEquals(msg + " " + i + " elements differ",
m[i],n[i],tolerance);
}
}
/** Useful for debugging */
private void dumpMatrix(RealMatrix m) {
for (int i = 0; i < m.getRowDimension(); i++) {
String os = "";
for (int j = 0; j < m.getColumnDimension(); j++) {
os += m.getEntry(i+1, j+1) + " ";
}
System.out.println(os);
}
}
}