Improved linear RealMatrixImpl and BigMatrixImpl performances.

The main changes are the following ones:
 - use directly the storage array when possible for
   diadic operations (add, subtract, multiply), as
   suggested by Phil, this avoids the cost of the
   generic getEntry method
 - replaced custom indices checks by simple use of
   the JVM checks and ArrayIndexOutOfBoundException
 - put row arrays reference in local variables to
   avoid multiple checks in double loops
 - use final variables where possible
 - removed unneeded array copying
 - added a constructor to build a matrix from an
   array without copying it where it makes sense

The speed gain is about 3X for multiplication. Performances
for this operation are now on par with Jama.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@662241 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-06-01 16:22:19 +00:00
parent 7e7207cd4f
commit 482ebca8f5
8 changed files with 738 additions and 363 deletions

View File

@ -106,8 +106,9 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
/**
* Create a new BigMatrix using <code>d</code> as the underlying
* data array.
* <p>
* The input array is copied, not referenced.</p>
* <p>The input array is copied, not referenced. This constructor has
* the same effect as calling {@link #BigMatrixImpl(BigDecimal[][], boolean)}
* with the second argument set to <code>true</code>.</p>
*
* @param d data for new matrix
* @throws IllegalArgumentException if <code>d</code> is not rectangular
@ -118,12 +119,52 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
this.copyIn(d);
lu = null;
}
/**
* Create a new BigMatrix using the input array as the underlying
* data array.
* <p>If an array is built specially in order to be embedded in a
* BigMatrix and not used directly, the <code>copyArray</code> may be
* set to <code>false</code. This will prevent the copying and improve
* performance as no new array will be built and no data will be copied.</p>
* @param d data for new matrix
* @param copyArray if true, the input array will be copied, otherwise
* it will be referenced
* @throws IllegalArgumentException if <code>d</code> is not rectangular
* (not all rows have the same length) or empty
* @throws NullPointerException if <code>d</code> is null
* @see #BigMatrix(BigDecimal[][])
*/
public BigMatrixImpl(BigDecimal[][] d, boolean copyArray) {
if (copyArray) {
copyIn(d);
} else {
if (d == null) {
throw new NullPointerException();
}
final int nRows = d.length;
if (nRows == 0) {
throw new IllegalArgumentException("Matrix must have at least one row.");
}
final int nCols = d[0].length;
if (nCols == 0) {
throw new IllegalArgumentException("Matrix must have at least one column.");
}
for (int r = 1; r < nRows; r++) {
if (d[r].length != nCols) {
throw new IllegalArgumentException("All input rows must have the same length.");
}
}
data = d;
}
lu = null;
}
/**
* Create a new BigMatrix using <code>d</code> as the underlying
* data array.
* <p>
* The input array is copied, not referenced.</p>
* <p>Since the underlying array will hold <code>BigDecimal</code>
* instances, it will be created.</p>
*
* @param d data for new matrix
* @throws IllegalArgumentException if <code>d</code> is not rectangular
@ -131,12 +172,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws NullPointerException if <code>d</code> is null
*/
public BigMatrixImpl(double[][] d) {
int nRows = d.length;
final int nRows = d.length;
if (nRows == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one row.");
}
int nCols = d[0].length;
final int nCols = d[0].length;
if (nCols == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one column.");
@ -161,12 +202,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws NullPointerException if <code>d</code> is null
*/
public BigMatrixImpl(String[][] d) {
int nRows = d.length;
final int nRows = d.length;
if (nRows == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one row.");
}
int nCols = d[0].length;
final int nCols = d[0].length;
if (nCols == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one column.");
@ -191,7 +232,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @param v column vector holding data for new matrix
*/
public BigMatrixImpl(BigDecimal[] v) {
int nRows = v.length;
final int nRows = v.length;
data = new BigDecimal[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = v[row];
@ -204,7 +245,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return the cloned matrix
*/
public BigMatrix copy() {
return new BigMatrixImpl(this.copyOut());
return new BigMatrixImpl(this.copyOut(), false);
}
/**
@ -212,47 +253,107 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
*
* @param m matrix to be added
* @return this + m
* @exception IllegalArgumentException if m is not the same size as this
* @throws IllegalArgumentException if m is not the same size as this
*/
public BigMatrix add(BigMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getColumnDimension() ||
this.getRowDimension() != m.getRowDimension()) {
try {
return add((BigMatrixImpl) m);
} catch (ClassCastException cce) {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outDataRow[col] = dataRow[col].add(m.getEntry(row, col));
}
}
return new BigMatrixImpl(outData, false);
}
}
/**
* Compute the sum of this and <code>m</code>.
*
* @param m matrix to be added
* @return this + m
* @throws IllegalArgumentException if m is not the same size as this
*/
public BigMatrixImpl add(BigMatrixImpl m) throws IllegalArgumentException {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] mRow = m.data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col].add(m.getEntry(row, col));
}
outDataRow[col] = dataRow[col].add(mRow[col]);
}
}
return new BigMatrixImpl(outData);
return new BigMatrixImpl(outData, false);
}
/**
* Compute this minus <code>m</code>.
*
* @param m matrix to be subtracted
* @return this + m
* @exception IllegalArgumentException if m is not the same size as *this
* @throws IllegalArgumentException if m is not the same size as this
*/
public BigMatrix subtract(BigMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getColumnDimension() ||
this.getRowDimension() != m.getRowDimension()) {
try {
return subtract((BigMatrixImpl) m);
} catch (ClassCastException cce) {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outDataRow[col] = dataRow[col].subtract(getEntry(row, col));
}
}
return new BigMatrixImpl(outData, false);
}
}
/**
* Compute this minus <code>m</code>.
*
* @param m matrix to be subtracted
* @return this + m
* @throws IllegalArgumentException if m is not the same size as this
*/
public BigMatrixImpl subtract(BigMatrixImpl m) throws IllegalArgumentException {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] mRow = m.data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col].subtract(m.getEntry(row, col));
}
outDataRow[col] = dataRow[col].subtract(mRow[col]);
}
}
return new BigMatrixImpl(outData);
return new BigMatrixImpl(outData, false);
}
/**
* Returns the result of adding d to each entry of this.
*
@ -260,34 +361,38 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return d + this
*/
public BigMatrix scalarAdd(BigDecimal d) {
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col].add(d);
outDataRow[col] = dataRow[col].add(d);
}
}
return new BigMatrixImpl(outData);
return new BigMatrixImpl(outData, false);
}
/**
* Returns the result multiplying each entry of this by <code>d</code>
* Returns the result of multiplying each entry of this by <code>d</code>
* @param d value to multiply all entries by
* @return d * this
*/
public BigMatrix scalarMultiply(BigDecimal d) {
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col].multiply(d);
outDataRow[col] = dataRow[col].multiply(d);
}
}
return new BigMatrixImpl(outData);
return new BigMatrixImpl(outData, false);
}
/**
* Returns the result of postmultiplying this by <code>m</code>.
* @param m matrix to postmultiply by
@ -296,26 +401,60 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* if columnDimension(this) != rowDimension(m)
*/
public BigMatrix multiply(BigMatrix m) throws IllegalArgumentException {
try {
return multiply((BigMatrixImpl) m);
} catch (ClassCastException cce) {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("Matrices are not multiplication compatible.");
}
final int nRows = this.getRowDimension();
final int nCols = m.getColumnDimension();
final int nSum = this.getColumnDimension();
final BigDecimal[][] outData = new BigDecimal[nRows][nCols];
for (int row = 0; row < nRows; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < nCols; col++) {
BigDecimal sum = ZERO;
for (int i = 0; i < nSum; i++) {
sum = sum.add(dataRow[i].multiply(m.getEntry(i, col)));
}
outDataRow[col] = sum;
}
}
return new BigMatrixImpl(outData, false);
}
}
/**
* Returns the result of postmultiplying this by <code>m</code>.
* @param m matrix to postmultiply by
* @return this*m
* @throws IllegalArgumentException
* if columnDimension(this) != rowDimension(m)
*/
public BigMatrixImpl multiply(BigMatrixImpl m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("Matrices are not multiplication compatible.");
}
int nRows = this.getRowDimension();
int nCols = m.getColumnDimension();
int nSum = this.getColumnDimension();
BigDecimal[][] outData = new BigDecimal[nRows][nCols];
BigDecimal sum = ZERO;
final int nRows = this.getRowDimension();
final int nCols = m.getColumnDimension();
final int nSum = this.getColumnDimension();
final BigDecimal[][] outData = new BigDecimal[nRows][nCols];
for (int row = 0; row < nRows; row++) {
final BigDecimal[] dataRow = data[row];
final BigDecimal[] outDataRow = outData[row];
for (int col = 0; col < nCols; col++) {
sum = ZERO;
BigDecimal sum = ZERO;
for (int i = 0; i < nSum; i++) {
sum = sum.add(data[row][i].multiply(m.getEntry(i, col)));
sum = sum.add(dataRow[i].multiply(m.data[i][col]));
}
outData[row][col] = sum;
outDataRow[col] = sum;
}
}
return new BigMatrixImpl(outData);
}
return new BigMatrixImpl(outData, false);
}
/**
* Returns the result premultiplying this by <code>m</code>.
* @param m matrix to premultiply by
@ -326,7 +465,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
public BigMatrix preMultiply(BigMatrix m) throws IllegalArgumentException {
return m.multiply(this);
}
/**
* Returns matrix entries as a two-dimensional array.
* <p>
@ -347,11 +486,11 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return 2-dimensional array of entries
*/
public double[][] getDataAsDoubleArray() {
int nRows = getRowDimension();
int nCols = getColumnDimension();
double d[][] = new double[nRows][nCols];
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
final double d[][] = new double[nRows][nCols];
for (int i = 0; i < nRows; i++) {
for (int j=0; j<nCols;j++) {
for (int j = 0; j < nCols; j++) {
d[i][j] = data[i][j].doubleValue();
}
}
@ -437,23 +576,23 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* specified rows and columns
* @exception MatrixIndexException if row or column selections are not valid
*/
public BigMatrix getSubMatrix(int startRow, int endRow, int startColumn,
int endColumn) throws MatrixIndexException {
public BigMatrix getSubMatrix(int startRow, int endRow,
int startColumn, int endColumn)
throws MatrixIndexException {
if (startRow < 0 || startRow > endRow || endRow > data.length ||
startColumn < 0 || startColumn > endColumn ||
endColumn > data[0].length ) {
throw new MatrixIndexException(
"invalid row or column index selection");
}
BigMatrixImpl subMatrix = new BigMatrixImpl(endRow - startRow+1,
endColumn - startColumn+1);
BigDecimal[][] subMatrixData = subMatrix.getDataRef();
final BigDecimal[][] subMatrixData =
new BigDecimal[endRow - startRow + 1][endColumn - startColumn + 1];
for (int i = startRow; i <= endRow; i++) {
for (int j = startColumn; j <= endColumn; j++) {
subMatrixData[i - startRow][j - startColumn] = data[i][j];
}
System.arraycopy(data[i], startColumn,
subMatrixData[i - startRow], 0,
endColumn - startColumn + 1);
}
return subMatrix;
return new BigMatrixImpl(subMatrixData, false);
}
/**
@ -468,25 +607,26 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* are not valid
*/
public BigMatrix getSubMatrix(int[] selectedRows, int[] selectedColumns)
throws MatrixIndexException {
throws MatrixIndexException {
if (selectedRows.length * selectedColumns.length == 0) {
throw new MatrixIndexException(
"selected row and column index arrays must be non-empty");
}
BigMatrixImpl subMatrix = new BigMatrixImpl(selectedRows.length,
selectedColumns.length);
BigDecimal[][] subMatrixData = subMatrix.getDataRef();
final BigDecimal[][] subMatrixData =
new BigDecimal[selectedRows.length][selectedColumns.length];
try {
for (int i = 0; i < selectedRows.length; i++) {
final BigDecimal[] subI = subMatrixData[i];
final BigDecimal[] dataSelectedI = data[selectedRows[i]];
for (int j = 0; j < selectedColumns.length; j++) {
subMatrixData[i][j] = data[selectedRows[i]][selectedColumns[j]];
subI[j] = dataSelectedI[selectedColumns[j]];
}
}
}
catch (ArrayIndexOutOfBoundsException e) {
throw new MatrixIndexException("matrix dimension mismatch");
}
return subMatrix;
return new BigMatrixImpl(subMatrixData, false);
}
/**
@ -522,12 +662,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
throw new MatrixIndexException
("invalid row or column index selection");
}
int nRows = subMatrix.length;
final int nRows = subMatrix.length;
if (nRows == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one row.");
}
int nCols = subMatrix[0].length;
final int nCols = subMatrix[0].length;
if (nCols == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one column.");
@ -566,10 +706,10 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if ( !isValidCoordinate( row, 0)) {
throw new MatrixIndexException("illegal row argument");
}
int ncols = this.getColumnDimension();
BigDecimal[][] out = new BigDecimal[1][ncols];
final int ncols = this.getColumnDimension();
final BigDecimal[][] out = new BigDecimal[1][ncols];
System.arraycopy(data[row], 0, out[0], 0, ncols);
return new BigMatrixImpl(out);
return new BigMatrixImpl(out, false);
}
/**
@ -584,12 +724,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if ( !isValidCoordinate( 0, column)) {
throw new MatrixIndexException("illegal column argument");
}
int nRows = this.getRowDimension();
BigDecimal[][] out = new BigDecimal[nRows][1];
final int nRows = this.getRowDimension();
final BigDecimal[][] out = new BigDecimal[nRows][1];
for (int row = 0; row < nRows; row++) {
out[row][0] = data[row][column];
}
return new BigMatrixImpl(out);
return new BigMatrixImpl(out, false);
}
/**
@ -606,8 +746,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if ( !isValidCoordinate( row, 0 ) ) {
throw new MatrixIndexException("illegal row argument");
}
int ncols = this.getColumnDimension();
BigDecimal[] out = new BigDecimal[ncols];
final int ncols = this.getColumnDimension();
final BigDecimal[] out = new BigDecimal[ncols];
System.arraycopy(data[row], 0, out, 0, ncols);
return out;
}
@ -627,8 +767,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if ( !isValidCoordinate( row, 0 ) ) {
throw new MatrixIndexException("illegal row argument");
}
int ncols = this.getColumnDimension();
double[] out = new double[ncols];
final int ncols = this.getColumnDimension();
final double[] out = new double[ncols];
for (int i=0;i<ncols;i++) {
out[i] = data[row][i].doubleValue();
}
@ -649,8 +789,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if ( !isValidCoordinate(0, col) ) {
throw new MatrixIndexException("illegal column argument");
}
int nRows = this.getRowDimension();
BigDecimal[] out = new BigDecimal[nRows];
final int nRows = this.getRowDimension();
final BigDecimal[] out = new BigDecimal[nRows];
for (int i = 0; i < nRows; i++) {
out[i] = data[i][col];
}
@ -672,8 +812,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if ( !isValidCoordinate( 0, col ) ) {
throw new MatrixIndexException("illegal column argument");
}
int nrows = this.getRowDimension();
double[] out = new double[nrows];
final int nrows = this.getRowDimension();
final double[] out = new double[nrows];
for (int i=0;i<nrows;i++) {
out[i] = data[i][col].doubleValue();
}
@ -697,10 +837,11 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
*/
public BigDecimal getEntry(int row, int column)
throws MatrixIndexException {
if (!isValidCoordinate(row,column)) {
try {
return data[row][column];
} catch (ArrayIndexOutOfBoundsException e) {
throw new MatrixIndexException("matrix entry does not exist");
}
return data[row][column];
}
/**
@ -729,16 +870,16 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return transpose matrix
*/
public BigMatrix transpose() {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
BigMatrixImpl out = new BigMatrixImpl(nCols, nRows);
BigDecimal[][] outData = out.getDataRef();
final int nRows = this.getRowDimension();
final int nCols = this.getColumnDimension();
final BigDecimal[][] outData = new BigDecimal[nCols][nRows];
for (int row = 0; row < nRows; row++) {
final BigDecimal[] dataRow = data[row];
for (int col = 0; col < nCols; col++) {
outData[col][row] = data[row][col];
outData[col][row] = dataRow[col];
}
}
return out;
return new BigMatrixImpl(outData, false);
}
/**
@ -748,8 +889,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws InvalidMatrixException if this is not invertible
*/
public BigMatrix inverse() throws InvalidMatrixException {
return solve(MatrixUtils.createBigIdentityMatrix
(this.getRowDimension()));
return solve(MatrixUtils.createBigIdentityMatrix(getRowDimension()));
}
/**
@ -846,9 +986,9 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if (v.length != this.getColumnDimension()) {
throw new IllegalArgumentException("vector has wrong length");
}
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
BigDecimal[] out = new BigDecimal[v.length];
final int nRows = this.getRowDimension();
final int nCols = this.getColumnDimension();
final BigDecimal[] out = new BigDecimal[v.length];
for (int row = 0; row < nRows; row++) {
BigDecimal sum = ZERO;
for (int i = 0; i < nCols; i++) {
@ -867,8 +1007,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws IllegalArgumentException if columnDimension != v.size()
*/
public BigDecimal[] operate(double[] v) throws IllegalArgumentException {
BigDecimal bd[] = new BigDecimal[v.length];
for (int i=0;i<bd.length;i++) {
final BigDecimal bd[] = new BigDecimal[v.length];
for (int i = 0; i < bd.length; i++) {
bd[i] = new BigDecimal(v[i]);
}
return operate(bd);
@ -882,12 +1022,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws IllegalArgumentException if rowDimension != v.size()
*/
public BigDecimal[] preMultiply(BigDecimal[] v) throws IllegalArgumentException {
int nRows = this.getRowDimension();
final int nRows = this.getRowDimension();
if (v.length != nRows) {
throw new IllegalArgumentException("vector has wrong length");
}
int nCols = this.getColumnDimension();
BigDecimal[] out = new BigDecimal[nCols];
final int nCols = this.getColumnDimension();
final BigDecimal[] out = new BigDecimal[nCols];
for (int col = 0; col < nCols; col++) {
BigDecimal sum = ZERO;
for (int i = 0; i < nRows; i++) {
@ -910,13 +1050,13 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws InvalidMatrixException if this matrix is not square or is singular
*/
public BigDecimal[] solve(BigDecimal[] b) throws IllegalArgumentException, InvalidMatrixException {
int nRows = this.getRowDimension();
final int nRows = this.getRowDimension();
if (b.length != nRows) {
throw new IllegalArgumentException("constant vector has wrong length");
}
BigMatrix bMatrix = new BigMatrixImpl(b);
BigDecimal[][] solution = ((BigMatrixImpl) (solve(bMatrix))).getDataRef();
BigDecimal[] out = new BigDecimal[nRows];
final BigMatrix bMatrix = new BigMatrixImpl(b);
final BigDecimal[][] solution = ((BigMatrixImpl) (solve(bMatrix))).getDataRef();
final BigDecimal[] out = new BigDecimal[nRows];
for (int row = 0; row < nRows; row++) {
out[row] = solution[row][0];
}
@ -935,8 +1075,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @throws InvalidMatrixException if this matrix is not square or is singular
*/
public BigDecimal[] solve(double[] b) throws IllegalArgumentException, InvalidMatrixException {
BigDecimal bd[] = new BigDecimal[b.length];
for (int i=0;i<bd.length;i++) {
final BigDecimal bd[] = new BigDecimal[b.length];
for (int i = 0; i < bd.length; i++) {
bd[i] = new BigDecimal(b[i]);
}
return solve(bd);
@ -964,41 +1104,48 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
throw new InvalidMatrixException("Matrix is singular.");
}
int nCol = this.getColumnDimension();
int nColB = b.getColumnDimension();
int nRowB = b.getRowDimension();
final int nCol = this.getColumnDimension();
final int nColB = b.getColumnDimension();
final int nRowB = b.getRowDimension();
// Apply permutations to b
BigDecimal[][] bp = new BigDecimal[nRowB][nColB];
final BigDecimal[][] bp = new BigDecimal[nRowB][nColB];
for (int row = 0; row < nRowB; row++) {
final BigDecimal[] bpRow = bp[row];
for (int col = 0; col < nColB; col++) {
bp[row][col] = b.getEntry(permutation[row], col);
bpRow[col] = b.getEntry(permutation[row], col);
}
}
// Solve LY = b
for (int col = 0; col < nCol; col++) {
for (int i = col + 1; i < nCol; i++) {
final BigDecimal[] bpI = bp[i];
final BigDecimal[] luI = lu[i];
for (int j = 0; j < nColB; j++) {
bp[i][j] = bp[i][j].subtract(bp[col][j].multiply(lu[i][col]));
bpI[j] = bpI[j].subtract(bp[col][j].multiply(luI[col]));
}
}
}
// Solve UX = Y
for (int col = nCol - 1; col >= 0; col--) {
final BigDecimal[] bpCol = bp[col];
final BigDecimal luDiag = lu[col][col];
for (int j = 0; j < nColB; j++) {
bp[col][j] = bp[col][j].divide(lu[col][col], scale, roundingMode);
bpCol[j] = bpCol[j].divide(luDiag, scale, roundingMode);
}
for (int i = 0; i < col; i++) {
final BigDecimal[] bpI = bp[i];
final BigDecimal[] luI = lu[i];
for (int j = 0; j < nColB; j++) {
bp[i][j] = bp[i][j].subtract(bp[col][j].multiply(lu[i][col]));
bpI[j] = bpI[j].subtract(bp[col][j].multiply(luI[col]));
}
}
}
BigMatrixImpl outMat = new BigMatrixImpl(bp);
return outMat;
return new BigMatrixImpl(bp, false);
}
/**
@ -1021,8 +1168,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
*/
public void luDecompose() throws InvalidMatrixException {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
final int nRows = this.getRowDimension();
final int nCols = this.getColumnDimension();
if (nRows != nCols) {
throw new InvalidMatrixException("LU decomposition requires that the matrix be square.");
}
@ -1042,22 +1189,24 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
// upper
for (int row = 0; row < col; row++) {
sum = lu[row][col];
final BigDecimal[] luRow = lu[row];
sum = luRow[col];
for (int i = 0; i < row; i++) {
sum = sum.subtract(lu[row][i].multiply(lu[i][col]));
sum = sum.subtract(luRow[i].multiply(lu[i][col]));
}
lu[row][col] = sum;
luRow[col] = sum;
}
// lower
int max = col; // permutation row
BigDecimal largest = ZERO;
for (int row = col; row < nRows; row++) {
sum = lu[row][col];
final BigDecimal[] luRow = lu[row];
sum = luRow[col];
for (int i = 0; i < col; i++) {
sum = sum.subtract(lu[row][i].multiply(lu[i][col]));
sum = sum.subtract(luRow[i].multiply(lu[i][col]));
}
lu[row][col] = sum;
luRow[col] = sum;
// maintain best permutation choice
if (sum.abs().compareTo(largest) == 1) {
@ -1086,9 +1235,11 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
parity = -parity;
}
//Divide the lower elements by the "winning" diagonal elt.
// Divide the lower elements by the "winning" diagonal elt.
final BigDecimal luDiag = lu[col][col];
for (int row = col + 1; row < nRows; row++) {
lu[row][col] = lu[row][col].divide(lu[col][col], scale, roundingMode);
final BigDecimal[] luRow = lu[row];
luRow[col] = luRow[col].divide(luDiag, scale, roundingMode);
}
}
@ -1104,12 +1255,14 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
res.append("BigMatrixImpl{");
if (data != null) {
for (int i = 0; i < data.length; i++) {
if (i > 0)
if (i > 0) {
res.append(",");
}
res.append("{");
for (int j = 0; j < data[0].length; j++) {
if (j > 0)
if (j > 0) {
res.append(",");
}
res.append(data[i][j]);
}
res.append("}");
@ -1135,15 +1288,16 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
if (object instanceof BigMatrixImpl == false) {
return false;
}
BigMatrix m = (BigMatrix) object;
int nRows = getRowDimension();
int nCols = getColumnDimension();
final BigMatrix m = (BigMatrix) object;
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
if (m.getColumnDimension() != nCols || m.getRowDimension() != nRows) {
return false;
}
for (int row = 0; row < nRows; row++) {
final BigDecimal[] dataRow = data[row];
for (int col = 0; col < nCols; col++) {
if (!data[row][col].equals(m.getEntry(row, col))) {
if (!dataRow[col].equals(m.getEntry(row, col))) {
return false;
}
}
@ -1158,14 +1312,15 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
*/
public int hashCode() {
int ret = 7;
int nRows = getRowDimension();
int nCols = getColumnDimension();
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
ret = ret * 31 + nRows;
ret = ret * 31 + nCols;
for (int row = 0; row < nRows; row++) {
final BigDecimal[] dataRow = data[row];
for (int col = 0; col < nCols; col++) {
ret = ret * 31 + (11 * (row+1) + 17 * (col+1)) *
data[row][col].hashCode();
dataRow[col].hashCode();
}
}
return ret;
@ -1220,7 +1375,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return the permutation
*/
protected int[] getPermutation() {
int[] out = new int[permutation.length];
final int[] out = new int[permutation.length];
System.arraycopy(permutation, 0, out, 0, permutation.length);
return out;
}
@ -1233,8 +1388,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return a copy of the underlying data array.
*/
private BigDecimal[][] copyOut() {
int nRows = this.getRowDimension();
BigDecimal[][] out = new BigDecimal[nRows][this.getColumnDimension()];
final int nRows = this.getRowDimension();
final BigDecimal[][] out = new BigDecimal[nRows][this.getColumnDimension()];
// can't copy 2-d array in one shot, otherwise get row references
for (int i = 0; i < nRows; i++) {
System.arraycopy(data[i], 0, out[i], 0, data[i].length);
@ -1262,12 +1417,14 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @param in data to copy in
*/
private void copyIn(double[][] in) {
int nRows = in.length;
int nCols = in[0].length;
final int nRows = in.length;
final int nCols = in[0].length;
data = new BigDecimal[nRows][nCols];
for (int i = 0; i < nRows; i++) {
for (int j=0; j < nCols; j++) {
data[i][j] = new BigDecimal(in[i][j]);
final BigDecimal[] dataI = data[i];
final double[] inI = in[i];
for (int j = 0; j < nCols; j++) {
dataI[j] = new BigDecimal(inI[j]);
}
}
lu = null;
@ -1280,12 +1437,14 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @param in data to copy in
*/
private void copyIn(String[][] in) {
int nRows = in.length;
int nCols = in[0].length;
final int nRows = in.length;
final int nCols = in[0].length;
data = new BigDecimal[nRows][nCols];
for (int i = 0; i < nRows; i++) {
for (int j=0; j < nCols; j++) {
data[i][j] = new BigDecimal(in[i][j]);
final BigDecimal[] dataI = data[i];
final String[] inI = in[i];
for (int j = 0; j < nCols; j++) {
dataI[j] = new BigDecimal(inI[j]);
}
}
lu = null;
@ -1299,9 +1458,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable {
* @return true if the coordinate is with the current dimensions
*/
private boolean isValidCoordinate(int row, int col) {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
final int nRows = this.getRowDimension();
final int nCols = this.getColumnDimension();
return !(row < 0 || row >= nRows || col < 0 || col >= nCols);
}

View File

@ -18,6 +18,7 @@
package org.apache.commons.math.linear;
import java.math.BigDecimal;
import java.util.Arrays;
/**
* A collection of static methods that operate on or return matrices.
@ -41,12 +42,33 @@ public class MatrixUtils {
* @return RealMatrix containing the values of the array
* @throws IllegalArgumentException if <code>data</code> is not rectangular
* (not all rows have the same length) or empty
* @throws NullPointerException if data is null
* @throws NullPointerException if <code>data</code> is null
* @see #createRealMatrix(double[][], boolean)
*/
public static RealMatrix createRealMatrix(double[][] data) {
return new RealMatrixImpl(data);
}
/**
* Returns a {@link RealMatrix} whose entries are the the values in the
* the input array.
* <p>If an array is built specially in order to be embedded in a
* RealMatrix and not used directly, the <code>copyArray</code> may be
* set to <code>false</code. This will prevent the copying and improve
* performance as no new array will be built and no data will be copied.</p>
* @param data data for new matrix
* @param copyArray if true, the input array will be copied, otherwise
* it will be referenced
* @return RealMatrix containing the values of the array
* @throws IllegalArgumentException if <code>data</code> is not rectangular
* (not all rows have the same length) or empty
* @throws NullPointerException if <code>data</code> is null
* @see #createRealMatrix(double[][])
*/
public static RealMatrix createRealMatrix(double[][] data, boolean copyArray) {
return new RealMatrixImpl(data, copyArray);
}
/**
* Returns <code>dimension x dimension</code> identity matrix.
*
@ -56,14 +78,11 @@ public class MatrixUtils {
* @since 1.1
*/
public static RealMatrix createRealIdentityMatrix(int dimension) {
RealMatrixImpl out = new RealMatrixImpl(dimension, dimension);
double[][] d = out.getDataRef();
double[][] d = new double[dimension][dimension];
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
d[row][col] = row == col ? 1d : 0d;
}
d[row][row] = 1d;
}
return out;
return new RealMatrixImpl(d, false);
}
/**
@ -93,7 +112,27 @@ public class MatrixUtils {
public static BigMatrix createBigMatrix(BigDecimal[][] data) {
return new BigMatrixImpl(data);
}
/**
* Returns a {@link BigMatrix} whose entries are the the values in the
* the input array.
* <p>If an array is built specially in order to be embedded in a
* BigMatrix and not used directly, the <code>copyArray</code> may be
* set to <code>false</code. This will prevent the copying and improve
* performance as no new array will be built and no data will be copied.</p>
* @param data data for new matrix
* @param copyArray if true, the input array will be copied, otherwise
* it will be referenced
* @return BigMatrix containing the values of the array
* @throws IllegalArgumentException if <code>data</code> is not rectangular
* (not all rows have the same length) or empty
* @throws NullPointerException if <code>data</code> is null
* @see #createRealMatrix(double[][])
*/
public static BigMatrix createBigMatrix(BigDecimal[][] data, boolean copyArray) {
return new BigMatrixImpl(data, copyArray);
}
/**
* Returns a {@link BigMatrix} whose entries are the the values in the
* the input array. The input array is copied, not referenced.
@ -118,10 +157,10 @@ public class MatrixUtils {
* @throws NullPointerException if <code>rowData</code>is null
*/
public static RealMatrix createRowRealMatrix(double[] rowData) {
int nCols = rowData.length;
double[][] data = new double[1][nCols];
final int nCols = rowData.length;
final double[][] data = new double[1][nCols];
System.arraycopy(rowData, 0, data[0], 0, nCols);
return new RealMatrixImpl(data);
return new RealMatrixImpl(data, false);
}
/**
@ -134,10 +173,12 @@ public class MatrixUtils {
* @throws NullPointerException if <code>rowData</code>is null
*/
public static BigMatrix createRowBigMatrix(double[] rowData) {
int nCols = rowData.length;
double[][] data = new double[1][nCols];
System.arraycopy(rowData, 0, data[0], 0, nCols);
return new BigMatrixImpl(data);
final int nCols = rowData.length;
final BigDecimal[][] data = new BigDecimal[1][nCols];
for (int i = 0; i < nCols; ++i) {
data[0][i] = new BigDecimal(rowData[i]);
}
return new BigMatrixImpl(data, false);
}
/**
@ -150,10 +191,10 @@ public class MatrixUtils {
* @throws NullPointerException if <code>rowData</code>is null
*/
public static BigMatrix createRowBigMatrix(BigDecimal[] rowData) {
int nCols = rowData.length;
BigDecimal[][] data = new BigDecimal[1][nCols];
final int nCols = rowData.length;
final BigDecimal[][] data = new BigDecimal[1][nCols];
System.arraycopy(rowData, 0, data[0], 0, nCols);
return new BigMatrixImpl(data);
return new BigMatrixImpl(data, false);
}
/**
@ -166,10 +207,12 @@ public class MatrixUtils {
* @throws NullPointerException if <code>rowData</code>is null
*/
public static BigMatrix createRowBigMatrix(String[] rowData) {
int nCols = rowData.length;
String[][] data = new String[1][nCols];
System.arraycopy(rowData, 0, data[0], 0, nCols);
return new BigMatrixImpl(data);
final int nCols = rowData.length;
final BigDecimal[][] data = new BigDecimal[1][nCols];
for (int i = 0; i < nCols; ++i) {
data[0][i] = new BigDecimal(rowData[i]);
}
return new BigMatrixImpl(data, false);
}
/**
@ -182,12 +225,12 @@ public class MatrixUtils {
* @throws NullPointerException if <code>columnData</code>is null
*/
public static RealMatrix createColumnRealMatrix(double[] columnData) {
int nRows = columnData.length;
double[][] data = new double[nRows][1];
final int nRows = columnData.length;
final double[][] data = new double[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = columnData[row];
}
return new RealMatrixImpl(data);
return new RealMatrixImpl(data, false);
}
/**
@ -200,12 +243,12 @@ public class MatrixUtils {
* @throws NullPointerException if <code>columnData</code>is null
*/
public static BigMatrix createColumnBigMatrix(double[] columnData) {
int nRows = columnData.length;
double[][] data = new double[nRows][1];
final int nRows = columnData.length;
final BigDecimal[][] data = new BigDecimal[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = columnData[row];
data[row][0] = new BigDecimal(columnData[row]);
}
return new BigMatrixImpl(data);
return new BigMatrixImpl(data, false);
}
/**
@ -218,12 +261,12 @@ public class MatrixUtils {
* @throws NullPointerException if <code>columnData</code>is null
*/
public static BigMatrix createColumnBigMatrix(BigDecimal[] columnData) {
int nRows = columnData.length;
BigDecimal[][] data = new BigDecimal[nRows][1];
final int nRows = columnData.length;
final BigDecimal[][] data = new BigDecimal[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = columnData[row];
}
return new BigMatrixImpl(data);
return new BigMatrixImpl(data, false);
}
/**
@ -237,11 +280,11 @@ public class MatrixUtils {
*/
public static BigMatrix createColumnBigMatrix(String[] columnData) {
int nRows = columnData.length;
String[][] data = new String[nRows][1];
final BigDecimal[][] data = new BigDecimal[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = columnData[row];
data[row][0] = new BigDecimal(columnData[row]);
}
return new BigMatrixImpl(data);
return new BigMatrixImpl(data, false);
}
/**
@ -253,14 +296,13 @@ public class MatrixUtils {
* @since 1.1
*/
public static BigMatrix createBigIdentityMatrix(int dimension) {
BigMatrixImpl out = new BigMatrixImpl(dimension, dimension);
BigDecimal[][] d = out.getDataRef();
final BigDecimal[][] d = new BigDecimal[dimension][dimension];
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
d[row][col] = row == col ? BigMatrixImpl.ONE : BigMatrixImpl.ZERO;
}
final BigDecimal[] dRow = d[row];
Arrays.fill(dRow, BigMatrixImpl.ZERO);
dRow[row] = BigMatrixImpl.ONE;
}
return out;
return new BigMatrixImpl(d, false);
}
}

View File

@ -52,7 +52,7 @@ import org.apache.commons.math.util.MathUtils;
public class RealMatrixImpl implements RealMatrix, Serializable {
/** Serializable version identifier */
private static final long serialVersionUID = 4237564493130426188L;
private static final long serialVersionUID = -4828886979278117018L;
/** Entries of the matrix */
private double data[][] = null;
@ -97,16 +97,58 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
/**
* Create a new RealMatrix using the input array as the underlying
* data array.
* <p>
* The input array is copied, not referenced.</p>
* <p>The input array is copied, not referenced. This constructor has
* the same effect as calling {@link #RealMatrixImpl(double[][], boolean)}
* with the second argument set to <code>true</code>.</p>
*
* @param d data for new matrix
* @throws IllegalArgumentException if <code>d</code> is not rectangular
* (not all rows have the same length) or empty
* @throws NullPointerException if <code>d</code> is null
* @see #RealMatrixImpl(double[][], boolean)
*/
public RealMatrixImpl(double[][] d) {
this.copyIn(d);
copyIn(d);
lu = null;
}
/**
* Create a new RealMatrix using the input array as the underlying
* data array.
* <p>If an array is built specially in order to be embedded in a
* RealMatrix and not used directly, the <code>copyArray</code> may be
* set to <code>false</code. This will prevent the copying and improve
* performance as no new array will be built and no data will be copied.</p>
* @param d data for new matrix
* @param copyArray if true, the input array will be copied, otherwise
* it will be referenced
* @throws IllegalArgumentException if <code>d</code> is not rectangular
* (not all rows have the same length) or empty
* @throws NullPointerException if <code>d</code> is null
* @see #RealMatrixImpl(double[][])
*/
public RealMatrixImpl(double[][] d, boolean copyArray) {
if (copyArray) {
copyIn(d);
} else {
if (d == null) {
throw new NullPointerException();
}
final int nRows = d.length;
if (nRows == 0) {
throw new IllegalArgumentException("Matrix must have at least one row.");
}
final int nCols = d[0].length;
if (nCols == 0) {
throw new IllegalArgumentException("Matrix must have at least one column.");
}
for (int r = 1; r < nRows; r++) {
if (d[r].length != nCols) {
throw new IllegalArgumentException("All input rows must have the same length.");
}
}
data = d;
}
lu = null;
}
@ -114,13 +156,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* Create a new (column) RealMatrix using <code>v</code> as the
* data for the unique column of the <code>v.length x 1</code> matrix
* created.
* <p>
* The input array is copied, not referenced.</p>
* <p>The input array is copied, not referenced.</p>
*
* @param v column vector holding data for new matrix
*/
public RealMatrixImpl(double[] v) {
int nRows = v.length;
final int nRows = v.length;
data = new double[nRows][1];
for (int row = 0; row < nRows; row++) {
data[row][0] = v[row];
@ -133,7 +174,7 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return the cloned matrix
*/
public RealMatrix copy() {
return new RealMatrixImpl(this.copyOut());
return new RealMatrixImpl(copyOut(), false);
}
/**
@ -144,19 +185,49 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @throws IllegalArgumentException if m is not the same size as this
*/
public RealMatrix add(RealMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getColumnDimension() ||
this.getRowDimension() != m.getRowDimension()) {
try {
return add((RealMatrixImpl) m);
} catch (ClassCastException cce) {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outDataRow[col] = dataRow[col] + m.getEntry(row, col);
}
}
return new RealMatrixImpl(outData, false);
}
}
/**
* Compute the sum of this and <code>m</code>.
*
* @param m matrix to be added
* @return this + m
* @throws IllegalArgumentException if m is not the same size as this
*/
public RealMatrixImpl add(RealMatrixImpl m) throws IllegalArgumentException {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
double[][] outData = new double[rowCount][columnCount];
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
final double[] mRow = m.data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col] + m.getEntry(row, col);
outDataRow[col] = dataRow[col] + mRow[col];
}
}
return new RealMatrixImpl(outData);
return new RealMatrixImpl(outData, false);
}
/**
@ -167,19 +238,49 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @throws IllegalArgumentException if m is not the same size as this
*/
public RealMatrix subtract(RealMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getColumnDimension() ||
this.getRowDimension() != m.getRowDimension()) {
try {
return subtract((RealMatrixImpl) m);
} catch (ClassCastException cce) {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outDataRow[col] = dataRow[col] - m.getEntry(row, col);
}
}
return new RealMatrixImpl(outData, false);
}
}
/**
* Compute this minus <code>m</code>.
*
* @param m matrix to be subtracted
* @return this + m
* @throws IllegalArgumentException if m is not the same size as this
*/
public RealMatrixImpl subtract(RealMatrixImpl m) throws IllegalArgumentException {
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw new IllegalArgumentException("matrix dimension mismatch");
}
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
double[][] outData = new double[rowCount][columnCount];
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
final double[] mRow = m.data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col] - m.getEntry(row, col);
}
outDataRow[col] = dataRow[col] - mRow[col];
}
}
return new RealMatrixImpl(outData);
return new RealMatrixImpl(outData, false);
}
/**
@ -189,32 +290,36 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return d + this
*/
public RealMatrix scalarAdd(double d) {
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
double[][] outData = new double[rowCount][columnCount];
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col] + d;
outDataRow[col] = dataRow[col] + d;
}
}
return new RealMatrixImpl(outData);
return new RealMatrixImpl(outData, false);
}
/**
* Returns the result multiplying each entry of this by <code>d</code>
* Returns the result of multiplying each entry of this by <code>d</code>
* @param d value to multiply all entries by
* @return d * this
*/
public RealMatrix scalarMultiply(double d) {
int rowCount = this.getRowDimension();
int columnCount = this.getColumnDimension();
double[][] outData = new double[rowCount][columnCount];
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < columnCount; col++) {
outData[row][col] = data[row][col] * d;
outDataRow[col] = dataRow[col] * d;
}
}
return new RealMatrixImpl(outData);
return new RealMatrixImpl(outData, false);
}
/**
@ -225,28 +330,62 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* if columnDimension(this) != rowDimension(m)
*/
public RealMatrix multiply(RealMatrix m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("Matrices are not multiplication compatible.");
}
int nRows = this.getRowDimension();
int nCols = m.getColumnDimension();
int nSum = this.getColumnDimension();
double[][] outData = new double[nRows][nCols];
double sum = 0;
for (int row = 0; row < nRows; row++) {
for (int col = 0; col < nCols; col++) {
sum = 0;
for (int i = 0; i < nSum; i++) {
sum += data[row][i] * m.getEntry(i, col);
}
outData[row][col] = sum;
try {
return multiply((RealMatrixImpl) m);
} catch (ClassCastException cce) {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("Matrices are not multiplication compatible.");
}
final int nRows = this.getRowDimension();
final int nCols = m.getColumnDimension();
final int nSum = this.getColumnDimension();
final double[][] outData = new double[nRows][nCols];
for (int row = 0; row < nRows; row++) {
final double[] dataRow = data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < nCols; col++) {
double sum = 0;
for (int i = 0; i < nSum; i++) {
sum += dataRow[i] * m.getEntry(i, col);
}
outDataRow[col] = sum;
}
}
return new RealMatrixImpl(outData, false);
}
return new RealMatrixImpl(outData);
}
/**
* Returns the result premultiplying this by <code>m</code>.
* Returns the result of postmultiplying this by <code>m</code>.
* @param m matrix to postmultiply by
* @return this*m
* @throws IllegalArgumentException
* if columnDimension(this) != rowDimension(m)
*/
public RealMatrixImpl multiply(RealMatrixImpl m) throws IllegalArgumentException {
if (this.getColumnDimension() != m.getRowDimension()) {
throw new IllegalArgumentException("Matrices are not multiplication compatible.");
}
final int nRows = this.getRowDimension();
final int nCols = m.getColumnDimension();
final int nSum = this.getColumnDimension();
final double[][] outData = new double[nRows][nCols];
for (int row = 0; row < nRows; row++) {
final double[] dataRow = data[row];
final double[] outDataRow = outData[row];
for (int col = 0; col < nCols; col++) {
double sum = 0;
for (int i = 0; i < nSum; i++) {
sum += dataRow[i] * m.data[i][col];
}
outDataRow[col] = sum;
}
}
return new RealMatrixImpl(outData, false);
}
/**
* Returns the result of premultiplying this by <code>m</code>.
* @param m matrix to premultiply by
* @return m * this
* @throws IllegalArgumentException
@ -306,23 +445,23 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* specified rows and columns
* @exception MatrixIndexException if row or column selections are not valid
*/
public RealMatrix getSubMatrix(int startRow, int endRow, int startColumn,
int endColumn) throws MatrixIndexException {
public RealMatrix getSubMatrix(int startRow, int endRow,
int startColumn, int endColumn)
throws MatrixIndexException {
if (startRow < 0 || startRow > endRow || endRow > data.length ||
startColumn < 0 || startColumn > endColumn ||
endColumn > data[0].length ) {
endColumn > data[0].length) {
throw new MatrixIndexException(
"invalid row or column index selection");
}
RealMatrixImpl subMatrix = new RealMatrixImpl(endRow - startRow+1,
endColumn - startColumn+1);
double[][] subMatrixData = subMatrix.getDataRef();
final double[][] subMatrixData =
new double[endRow - startRow + 1][endColumn - startColumn + 1];
for (int i = startRow; i <= endRow; i++) {
for (int j = startColumn; j <= endColumn; j++) {
subMatrixData[i - startRow][j - startColumn] = data[i][j];
}
}
return subMatrix;
System.arraycopy(data[i], startColumn,
subMatrixData[i - startRow], 0,
endColumn - startColumn + 1);
}
return new RealMatrixImpl(subMatrixData, false);
}
/**
@ -337,25 +476,25 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* are not valid
*/
public RealMatrix getSubMatrix(int[] selectedRows, int[] selectedColumns)
throws MatrixIndexException {
throws MatrixIndexException {
if (selectedRows.length * selectedColumns.length == 0) {
throw new MatrixIndexException(
"selected row and column index arrays must be non-empty");
}
RealMatrixImpl subMatrix = new RealMatrixImpl(selectedRows.length,
selectedColumns.length);
double[][] subMatrixData = subMatrix.getDataRef();
final double[][] subMatrixData =
new double[selectedRows.length][selectedColumns.length];
try {
for (int i = 0; i < selectedRows.length; i++) {
final double[] subI = subMatrixData[i];
final double[] dataSelectedI = data[selectedRows[i]];
for (int j = 0; j < selectedColumns.length; j++) {
subMatrixData[i][j] = data[selectedRows[i]][selectedColumns[j]];
subI[j] = dataSelectedI[selectedColumns[j]];
}
}
}
catch (ArrayIndexOutOfBoundsException e) {
} catch (ArrayIndexOutOfBoundsException e) {
throw new MatrixIndexException("matrix dimension mismatch");
}
return subMatrix;
return new RealMatrixImpl(subMatrixData, false);
}
/**
@ -391,12 +530,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
throw new MatrixIndexException
("invalid row or column index selection");
}
int nRows = subMatrix.length;
final int nRows = subMatrix.length;
if (nRows == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one row.");
}
int nCols = subMatrix[0].length;
final int nCols = subMatrix[0].length;
if (nCols == 0) {
throw new IllegalArgumentException(
"Matrix must have at least one column.");
@ -435,10 +574,10 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
if ( !isValidCoordinate( row, 0)) {
throw new MatrixIndexException("illegal row argument");
}
int ncols = this.getColumnDimension();
double[][] out = new double[1][ncols];
final int ncols = this.getColumnDimension();
final double[][] out = new double[1][ncols];
System.arraycopy(data[row], 0, out[0], 0, ncols);
return new RealMatrixImpl(out);
return new RealMatrixImpl(out, false);
}
/**
@ -453,12 +592,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
if ( !isValidCoordinate( 0, column)) {
throw new MatrixIndexException("illegal column argument");
}
int nRows = this.getRowDimension();
double[][] out = new double[nRows][1];
final int nRows = this.getRowDimension();
final double[][] out = new double[nRows][1];
for (int row = 0; row < nRows; row++) {
out[row][0] = data[row][column];
}
return new RealMatrixImpl(out);
return new RealMatrixImpl(out, false);
}
/**
@ -475,8 +614,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
if ( !isValidCoordinate( row, 0 ) ) {
throw new MatrixIndexException("illegal row argument");
}
int ncols = this.getColumnDimension();
double[] out = new double[ncols];
final int ncols = this.getColumnDimension();
final double[] out = new double[ncols];
System.arraycopy(data[row], 0, out, 0, ncols);
return out;
}
@ -495,8 +634,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
if ( !isValidCoordinate(0, col) ) {
throw new MatrixIndexException("illegal column argument");
}
int nRows = this.getRowDimension();
double[] out = new double[nRows];
final int nRows = this.getRowDimension();
final double[] out = new double[nRows];
for (int row = 0; row < nRows; row++) {
out[row] = data[row][col];
}
@ -520,10 +659,11 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
*/
public double getEntry(int row, int column)
throws MatrixIndexException {
if (!isValidCoordinate(row,column)) {
try {
return data[row][column];
} catch (ArrayIndexOutOfBoundsException e) {
throw new MatrixIndexException("matrix entry does not exist");
}
return data[row][column];
}
/**
@ -532,16 +672,16 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return transpose matrix
*/
public RealMatrix transpose() {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
RealMatrixImpl out = new RealMatrixImpl(nCols, nRows);
double[][] outData = out.getDataRef();
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
final double[][] outData = new double[nCols][nRows];
for (int row = 0; row < nRows; row++) {
final double[] dataRow = data[row];
for (int col = 0; col < nCols; col++) {
outData[col][row] = data[row][col];
outData[col][row] = dataRow[col];
}
}
return out;
return new RealMatrixImpl(outData, false);
}
/**
@ -551,8 +691,7 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @throws InvalidMatrixException if this is not invertible
*/
public RealMatrix inverse() throws InvalidMatrixException {
return solve(MatrixUtils.createRealIdentityMatrix
(this.getRowDimension()));
return solve(MatrixUtils.createRealIdentityMatrix(getRowDimension()));
}
/**
@ -632,16 +771,17 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return resulting vector
*/
public double[] operate(double[] v) throws IllegalArgumentException {
if (v.length != this.getColumnDimension()) {
final int nRows = this.getRowDimension();
final int nCols = this.getColumnDimension();
if (v.length != nCols) {
throw new IllegalArgumentException("vector has wrong length");
}
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
double[] out = new double[v.length];
final double[] out = new double[v.length];
for (int row = 0; row < nRows; row++) {
final double[] dataRow = data[row];
double sum = 0;
for (int i = 0; i < nCols; i++) {
sum += data[row][i] * v[i];
sum += dataRow[i] * v[i];
}
out[row] = sum;
}
@ -654,12 +794,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return resulting matrix
*/
public double[] preMultiply(double[] v) throws IllegalArgumentException {
int nRows = this.getRowDimension();
final int nRows = this.getRowDimension();
if (v.length != nRows) {
throw new IllegalArgumentException("vector has wrong length");
}
int nCols = this.getColumnDimension();
double[] out = new double[nCols];
final int nCols = this.getColumnDimension();
final double[] out = new double[nCols];
for (int col = 0; col < nCols; col++) {
double sum = 0;
for (int i = 0; i < nRows; i++) {
@ -682,13 +822,13 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @throws InvalidMatrixException if this matrix is not square or is singular
*/
public double[] solve(double[] b) throws IllegalArgumentException, InvalidMatrixException {
int nRows = this.getRowDimension();
final int nRows = this.getRowDimension();
if (b.length != nRows) {
throw new IllegalArgumentException("constant vector has wrong length");
}
RealMatrix bMatrix = new RealMatrixImpl(b);
double[][] solution = ((RealMatrixImpl) (solve(bMatrix))).getDataRef();
double[] out = new double[nRows];
final RealMatrix bMatrix = new RealMatrixImpl(b);
final double[][] solution = ((RealMatrixImpl) (solve(bMatrix))).getDataRef();
final double[] out = new double[nRows];
for (int row = 0; row < nRows; row++) {
out[row] = solution[row][0];
}
@ -717,41 +857,48 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
throw new InvalidMatrixException("Matrix is singular.");
}
int nCol = this.getColumnDimension();
int nColB = b.getColumnDimension();
int nRowB = b.getRowDimension();
final int nCol = this.getColumnDimension();
final int nColB = b.getColumnDimension();
final int nRowB = b.getRowDimension();
// Apply permutations to b
double[][] bp = new double[nRowB][nColB];
final double[][] bp = new double[nRowB][nColB];
for (int row = 0; row < nRowB; row++) {
final double[] bpRow = bp[row];
for (int col = 0; col < nColB; col++) {
bp[row][col] = b.getEntry(permutation[row], col);
bpRow[col] = b.getEntry(permutation[row], col);
}
}
// Solve LY = b
for (int col = 0; col < nCol; col++) {
for (int i = col + 1; i < nCol; i++) {
final double[] bpI = bp[i];
final double[] luI = lu[i];
for (int j = 0; j < nColB; j++) {
bp[i][j] -= bp[col][j] * lu[i][col];
bpI[j] -= bp[col][j] * luI[col];
}
}
}
// Solve UX = Y
for (int col = nCol - 1; col >= 0; col--) {
final double[] bpCol = bp[col];
final double luDiag = lu[col][col];
for (int j = 0; j < nColB; j++) {
bp[col][j] /= lu[col][col];
bpCol[j] /= luDiag;
}
for (int i = 0; i < col; i++) {
final double[] bpI = bp[i];
final double[] luI = lu[i];
for (int j = 0; j < nColB; j++) {
bp[i][j] -= bp[col][j] * lu[i][col];
bpI[j] -= bp[col][j] * luI[col];
}
}
}
RealMatrixImpl outMat = new RealMatrixImpl(bp);
return outMat;
return new RealMatrixImpl(bp, false);
}
/**
@ -774,12 +921,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
*/
public void luDecompose() throws InvalidMatrixException {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
final int nRows = this.getRowDimension();
final int nCols = this.getColumnDimension();
if (nRows != nCols) {
throw new InvalidMatrixException("LU decomposition requires that the matrix be square.");
}
lu = this.getData();
lu = getData();
// Initialize permutation array and parity
permutation = new int[nRows];
@ -795,22 +942,24 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
// upper
for (int row = 0; row < col; row++) {
sum = lu[row][col];
final double[] luRow = lu[row];
sum = luRow[col];
for (int i = 0; i < row; i++) {
sum -= lu[row][i] * lu[i][col];
sum -= luRow[i] * lu[i][col];
}
lu[row][col] = sum;
luRow[col] = sum;
}
// lower
int max = col; // permutation row
double largest = 0d;
for (int row = col; row < nRows; row++) {
sum = lu[row][col];
final double[] luRow = lu[row];
sum = luRow[col];
for (int i = 0; i < col; i++) {
sum -= lu[row][i] * lu[i][col];
sum -= luRow[i] * lu[i][col];
}
lu[row][col] = sum;
luRow[col] = sum;
// maintain best permutation choice
if (Math.abs(sum) > largest) {
@ -839,9 +988,10 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
parity = -parity;
}
//Divide the lower elements by the "winning" diagonal elt.
// Divide the lower elements by the "winning" diagonal elt.
final double luDiag = lu[col][col];
for (int row = col + 1; row < nRows; row++) {
lu[row][col] /= lu[col][col];
lu[row][col] /= luDiag;
}
}
}
@ -855,12 +1005,14 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
res.append("RealMatrixImpl{");
if (data != null) {
for (int i = 0; i < data.length; i++) {
if (i > 0)
if (i > 0) {
res.append(",");
}
res.append("{");
for (int j = 0; j < data[0].length; j++) {
if (j > 0)
if (j > 0) {
res.append(",");
}
res.append(data[i][j]);
}
res.append("}");
@ -887,14 +1039,15 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
return false;
}
RealMatrix m = (RealMatrix) object;
int nRows = getRowDimension();
int nCols = getColumnDimension();
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
if (m.getColumnDimension() != nCols || m.getRowDimension() != nRows) {
return false;
}
for (int row = 0; row < nRows; row++) {
final double[] dataRow = data[row];
for (int col = 0; col < nCols; col++) {
if (Double.doubleToLongBits(data[row][col]) !=
if (Double.doubleToLongBits(dataRow[col]) !=
Double.doubleToLongBits(m.getEntry(row, col))) {
return false;
}
@ -910,14 +1063,15 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
*/
public int hashCode() {
int ret = 7;
int nRows = getRowDimension();
int nCols = getColumnDimension();
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
ret = ret * 31 + nRows;
ret = ret * 31 + nCols;
for (int row = 0; row < nRows; row++) {
for (int col = 0; col < nCols; col++) {
final double[] dataRow = data[row];
for (int col = 0; col < nCols; col++) {
ret = ret * 31 + (11 * (row+1) + 17 * (col+1)) *
MathUtils.hash(data[row][col]);
MathUtils.hash(dataRow[col]);
}
}
return ret;
@ -972,7 +1126,7 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return the permutation
*/
protected int[] getPermutation() {
int[] out = new int[permutation.length];
final int[] out = new int[permutation.length];
System.arraycopy(permutation, 0, out, 0, permutation.length);
return out;
}
@ -985,8 +1139,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return a copy of the underlying data array.
*/
private double[][] copyOut() {
int nRows = this.getRowDimension();
double[][] out = new double[nRows][this.getColumnDimension()];
final int nRows = this.getRowDimension();
final double[][] out = new double[nRows][this.getColumnDimension()];
// can't copy 2-d array in one shot, otherwise get row references
for (int i = 0; i < nRows; i++) {
System.arraycopy(data[i], 0, out[i], 0, data[i].length);
@ -1016,9 +1170,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable {
* @return true if the coordinate is with the current dimensions
*/
private boolean isValidCoordinate(int row, int col) {
int nRows = this.getRowDimension();
int nCols = this.getColumnDimension();
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
return !(row < 0 || row > nRows - 1 || col < 0 || col > nCols -1);
}

View File

@ -39,6 +39,10 @@ The <action> type attribute can be add,update,fix,remove.
</properties>
<body>
<release version="2.0" date="TBD" description="TBD">
<action dev="luc" type="update" >
Greatly improved RealMatrixImpl and BigMatrixImpl performances,
both in terms of speed and in terms of temporary memory footprint.
</action>
<action dev="luc" type="add" issue="MATH-203" due-to="Mauro Talevi">
Added Mauro's patch to support multiple regression.
</action>

View File

@ -156,9 +156,12 @@ public final class BigMatrixImplTest extends TestCase {
/** test copy functions */
public void testCopyFunctions() {
BigMatrixImpl m = new BigMatrixImpl(testData);
BigMatrixImpl m2 = new BigMatrixImpl(m.getData());
assertEquals(m2,m);
BigMatrixImpl m1 = new BigMatrixImpl(testData);
BigMatrixImpl m2 = new BigMatrixImpl(m1.getData());
assertEquals(m2,m1);
BigMatrixImpl m3 = new BigMatrixImpl(testData);
BigMatrixImpl m4 = new BigMatrixImpl(m3.getData(), false);
assertEquals(m4,m3);
}
/** test constructors */
@ -166,9 +169,13 @@ public final class BigMatrixImplTest extends TestCase {
BigMatrix m1 = new BigMatrixImpl(testData);
BigMatrix m2 = new BigMatrixImpl(testDataString);
BigMatrix m3 = new BigMatrixImpl(asBigDecimal(testData));
BigMatrix m4 = new BigMatrixImpl(asBigDecimal(testData), true);
BigMatrix m5 = new BigMatrixImpl(asBigDecimal(testData), false);
assertClose("double, string", m1, m2, Double.MIN_VALUE);
assertClose("double, BigDecimal", m1, m3, Double.MIN_VALUE);
assertClose("string, BigDecimal", m2, m3, Double.MIN_VALUE);
assertClose("double, BigDecimal/true", m1, m4, Double.MIN_VALUE);
assertClose("double, BigDecimal/false", m1, m5, Double.MIN_VALUE);
try {
new BigMatrixImpl(new String[][] {{"0", "hello", "1"}});
fail("Expecting NumberFormatException");
@ -212,7 +219,7 @@ public final class BigMatrixImplTest extends TestCase {
public void testAdd() {
BigMatrixImpl m = new BigMatrixImpl(testData);
BigMatrixImpl mInv = new BigMatrixImpl(testDataInv);
BigMatrixImpl mPlusMInv = (BigMatrixImpl)m.add(mInv);
BigMatrix mPlusMInv = m.add(mInv);
double[][] sumEntries = asDouble(mPlusMInv.getData());
for (int row = 0; row < m.getRowDimension(); row++) {
for (int col = 0; col < m.getColumnDimension(); col++) {

View File

@ -65,6 +65,10 @@ public final class MatrixUtilsTest extends TestCase {
public void testCreateRealMatrix() {
assertEquals(new RealMatrixImpl(testData),
MatrixUtils.createRealMatrix(testData));
assertEquals(new RealMatrixImpl(testData, false),
MatrixUtils.createRealMatrix(testData, true));
assertEquals(new RealMatrixImpl(testData, true),
MatrixUtils.createRealMatrix(testData, false));
try {
MatrixUtils.createRealMatrix(new double[][] {{1}, {1,2}}); // ragged
fail("Expecting IllegalArgumentException");
@ -88,6 +92,10 @@ public final class MatrixUtilsTest extends TestCase {
public void testCreateBigMatrix() {
assertEquals(new BigMatrixImpl(testData),
MatrixUtils.createBigMatrix(testData));
assertEquals(new BigMatrixImpl(BigMatrixImplTest.asBigDecimal(testData), true),
MatrixUtils.createBigMatrix(BigMatrixImplTest.asBigDecimal(testData), false));
assertEquals(new BigMatrixImpl(BigMatrixImplTest.asBigDecimal(testData), false),
MatrixUtils.createBigMatrix(BigMatrixImplTest.asBigDecimal(testData), true));
assertEquals(new BigMatrixImpl(bigColMatrix),
MatrixUtils.createBigMatrix(bigColMatrix));
assertEquals(new BigMatrixImpl(stringColMatrix),

View File

@ -59,21 +59,21 @@ public class QRDecompositionImplTest extends TestCase {
/** test dimensions */
public void testDimensions() {
RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular);
RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular, false);
QRDecomposition qr = new QRDecompositionImpl(matrix);
assertEquals("3x3 Q size", qr.getQ().getRowDimension(), 3);
assertEquals("3x3 Q size", qr.getQ().getColumnDimension(), 3);
assertEquals("3x3 R size", qr.getR().getRowDimension(), 3);
assertEquals("3x3 R size", qr.getR().getColumnDimension(), 3);
matrix = new RealMatrixImpl(testData4x3);
matrix = new RealMatrixImpl(testData4x3, false);
qr = new QRDecompositionImpl(matrix);
assertEquals("4x3 Q size", qr.getQ().getRowDimension(), 4);
assertEquals("4x3 Q size", qr.getQ().getColumnDimension(), 4);
assertEquals("4x3 R size", qr.getR().getRowDimension(), 4);
assertEquals("4x3 R size", qr.getR().getColumnDimension(), 3);
matrix = new RealMatrixImpl(testData3x4);
matrix = new RealMatrixImpl(testData3x4, false);
qr = new QRDecompositionImpl(matrix);
assertEquals("3x4 Q size", qr.getQ().getRowDimension(), 3);
assertEquals("3x4 Q size", qr.getQ().getColumnDimension(), 3);
@ -83,24 +83,24 @@ public class QRDecompositionImplTest extends TestCase {
/** test A = QR */
public void testAEqualQR() {
RealMatrix A = new RealMatrixImpl(testData3x3NonSingular);
RealMatrix A = new RealMatrixImpl(testData3x3NonSingular, false);
QRDecomposition qr = new QRDecompositionImpl(A);
RealMatrix Q = qr.getQ();
RealMatrix R = qr.getR();
double norm = Q.multiply(R).subtract(A).getNorm();
assertEquals("3x3 nonsingular A = QR", 0, norm, normTolerance);
RealMatrix matrix = new RealMatrixImpl(testData3x3Singular);
RealMatrix matrix = new RealMatrixImpl(testData3x3Singular, false);
qr = new QRDecompositionImpl(matrix);
norm = qr.getQ().multiply(qr.getR()).subtract(matrix).getNorm();
assertEquals("3x3 singular A = QR", 0, norm, normTolerance);
matrix = new RealMatrixImpl(testData3x4);
matrix = new RealMatrixImpl(testData3x4, false);
qr = new QRDecompositionImpl(matrix);
norm = qr.getQ().multiply(qr.getR()).subtract(matrix).getNorm();
assertEquals("3x4 A = QR", 0, norm, normTolerance);
matrix = new RealMatrixImpl(testData4x3);
matrix = new RealMatrixImpl(testData4x3, false);
qr = new QRDecompositionImpl(matrix);
norm = qr.getQ().multiply(qr.getR()).subtract(matrix).getNorm();
assertEquals("4x3 A = QR", 0, norm, normTolerance);
@ -108,28 +108,28 @@ public class QRDecompositionImplTest extends TestCase {
/** test the orthogonality of Q */
public void testQOrthogonal() {
RealMatrix matrix = new RealMatrixImpl(testData3x3NonSingular);
RealMatrix matrix = new RealMatrixImpl(testData3x3NonSingular, false);
matrix = new QRDecompositionImpl(matrix).getQ();
RealMatrix eye = MatrixUtils.createRealIdentityMatrix(3);
double norm = matrix.transpose().multiply(matrix).subtract(eye)
.getNorm();
assertEquals("3x3 nonsingular Q'Q = I", 0, norm, normTolerance);
matrix = new RealMatrixImpl(testData3x3Singular);
matrix = new RealMatrixImpl(testData3x3Singular, false);
matrix = new QRDecompositionImpl(matrix).getQ();
eye = MatrixUtils.createRealIdentityMatrix(3);
norm = matrix.transpose().multiply(matrix).subtract(eye)
.getNorm();
assertEquals("3x3 singular Q'Q = I", 0, norm, normTolerance);
matrix = new RealMatrixImpl(testData3x4);
matrix = new RealMatrixImpl(testData3x4, false);
matrix = new QRDecompositionImpl(matrix).getQ();
eye = MatrixUtils.createRealIdentityMatrix(3);
norm = matrix.transpose().multiply(matrix).subtract(eye)
.getNorm();
assertEquals("3x4 Q'Q = I", 0, norm, normTolerance);
matrix = new RealMatrixImpl(testData4x3);
matrix = new RealMatrixImpl(testData4x3, false);
matrix = new QRDecompositionImpl(matrix).getQ();
eye = MatrixUtils.createRealIdentityMatrix(4);
norm = matrix.transpose().multiply(matrix).subtract(eye)
@ -139,21 +139,21 @@ public class QRDecompositionImplTest extends TestCase {
/** test that R is upper triangular */
public void testRUpperTriangular() {
RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular);
RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular, false);
RealMatrix R = new QRDecompositionImpl(matrix).getR();
for (int i = 0; i < R.getRowDimension(); i++)
for (int j = 0; j < i; j++)
assertEquals("R lower triangle", R.getEntry(i, j), 0,
entryTolerance);
matrix = new RealMatrixImpl(testData3x4);
matrix = new RealMatrixImpl(testData3x4, false);
R = new QRDecompositionImpl(matrix).getR();
for (int i = 0; i < R.getRowDimension(); i++)
for (int j = 0; j < i; j++)
assertEquals("R lower triangle", R.getEntry(i, j), 0,
entryTolerance);
matrix = new RealMatrixImpl(testData4x3);
matrix = new RealMatrixImpl(testData4x3, false);
R = new QRDecompositionImpl(matrix).getR();
for (int i = 0; i < R.getRowDimension(); i++)
for (int j = 0; j < i; j++)

View File

@ -116,16 +116,19 @@ public final class RealMatrixImplTest extends TestCase {
/** test copy functions */
public void testCopyFunctions() {
RealMatrixImpl m = new RealMatrixImpl(testData);
RealMatrixImpl m2 = new RealMatrixImpl(m.getData());
assertEquals(m2,m);
RealMatrixImpl m1 = new RealMatrixImpl(testData);
RealMatrixImpl m2 = new RealMatrixImpl(m1.getData());
assertEquals(m2,m1);
RealMatrixImpl m3 = new RealMatrixImpl(testData);
RealMatrixImpl m4 = new RealMatrixImpl(m3.getData(), false);
assertEquals(m4,m3);
}
/** test add */
public void testAdd() {
RealMatrixImpl m = new RealMatrixImpl(testData);
RealMatrixImpl mInv = new RealMatrixImpl(testDataInv);
RealMatrixImpl mPlusMInv = (RealMatrixImpl)m.add(mInv);
RealMatrix mPlusMInv = m.add(mInv);
double[][] sumEntries = mPlusMInv.getData();
for (int row = 0; row < m.getRowDimension(); row++) {
for (int col = 0; col < m.getColumnDimension(); col++) {