From 482ebca8f54c6d1c6ef3d07710d0717334bc0eee Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Sun, 1 Jun 2008 16:22:19 +0000 Subject: [PATCH] Improved linear RealMatrixImpl and BigMatrixImpl performances. The main changes are the following ones: - use directly the storage array when possible for diadic operations (add, subtract, multiply), as suggested by Phil, this avoids the cost of the generic getEntry method - replaced custom indices checks by simple use of the JVM checks and ArrayIndexOutOfBoundException - put row arrays reference in local variables to avoid multiple checks in double loops - use final variables where possible - removed unneeded array copying - added a constructor to build a matrix from an array without copying it where it makes sense The speed gain is about 3X for multiplication. Performances for this operation are now on par with Jama. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/branches/MATH_2_0@662241 13f79535-47bb-0310-9956-ffa450edef68 --- .../commons/math/linear/BigMatrixImpl.java | 476 ++++++++++++------ .../commons/math/linear/MatrixUtils.java | 126 +++-- .../commons/math/linear/RealMatrixImpl.java | 433 ++++++++++------ src/site/xdoc/changes.xml | 4 + .../math/linear/BigMatrixImplTest.java | 15 +- .../commons/math/linear/MatrixUtilsTest.java | 8 + .../math/linear/QRDecompositionImplTest.java | 28 +- .../math/linear/RealMatrixImplTest.java | 11 +- 8 files changed, 738 insertions(+), 363 deletions(-) diff --git a/src/java/org/apache/commons/math/linear/BigMatrixImpl.java b/src/java/org/apache/commons/math/linear/BigMatrixImpl.java index 2363c1431..09eaf78b6 100644 --- a/src/java/org/apache/commons/math/linear/BigMatrixImpl.java +++ b/src/java/org/apache/commons/math/linear/BigMatrixImpl.java @@ -106,8 +106,9 @@ public class BigMatrixImpl implements BigMatrix, Serializable { /** * Create a new BigMatrix using d as the underlying * data array. - *

- * The input array is copied, not referenced.

+ *

The input array is copied, not referenced. This constructor has + * the same effect as calling {@link #BigMatrixImpl(BigDecimal[][], boolean)} + * with the second argument set to true.

* * @param d data for new matrix * @throws IllegalArgumentException if d is not rectangular @@ -118,12 +119,52 @@ public class BigMatrixImpl implements BigMatrix, Serializable { this.copyIn(d); lu = null; } - + + /** + * Create a new BigMatrix using the input array as the underlying + * data array. + *

If an array is built specially in order to be embedded in a + * BigMatrix and not used directly, the copyArray may be + * set to false + * @param d data for new matrix + * @param copyArray if true, the input array will be copied, otherwise + * it will be referenced + * @throws IllegalArgumentException if d is not rectangular + * (not all rows have the same length) or empty + * @throws NullPointerException if d is null + * @see #BigMatrix(BigDecimal[][]) + */ + public BigMatrixImpl(BigDecimal[][] d, boolean copyArray) { + if (copyArray) { + copyIn(d); + } else { + if (d == null) { + throw new NullPointerException(); + } + final int nRows = d.length; + if (nRows == 0) { + throw new IllegalArgumentException("Matrix must have at least one row."); + } + final int nCols = d[0].length; + if (nCols == 0) { + throw new IllegalArgumentException("Matrix must have at least one column."); + } + for (int r = 1; r < nRows; r++) { + if (d[r].length != nCols) { + throw new IllegalArgumentException("All input rows must have the same length."); + } + } + data = d; + } + lu = null; + } + /** * Create a new BigMatrix using d as the underlying * data array. - *

- * The input array is copied, not referenced.

+ *

Since the underlying array will hold BigDecimal + * instances, it will be created.

* * @param d data for new matrix * @throws IllegalArgumentException if d is not rectangular @@ -131,12 +172,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @throws NullPointerException if d is null */ public BigMatrixImpl(double[][] d) { - int nRows = d.length; + final int nRows = d.length; if (nRows == 0) { throw new IllegalArgumentException( "Matrix must have at least one row."); } - int nCols = d[0].length; + final int nCols = d[0].length; if (nCols == 0) { throw new IllegalArgumentException( "Matrix must have at least one column."); @@ -161,12 +202,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @throws NullPointerException if d is null */ public BigMatrixImpl(String[][] d) { - int nRows = d.length; + final int nRows = d.length; if (nRows == 0) { throw new IllegalArgumentException( "Matrix must have at least one row."); } - int nCols = d[0].length; + final int nCols = d[0].length; if (nCols == 0) { throw new IllegalArgumentException( "Matrix must have at least one column."); @@ -191,7 +232,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @param v column vector holding data for new matrix */ public BigMatrixImpl(BigDecimal[] v) { - int nRows = v.length; + final int nRows = v.length; data = new BigDecimal[nRows][1]; for (int row = 0; row < nRows; row++) { data[row][0] = v[row]; @@ -204,7 +245,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @return the cloned matrix */ public BigMatrix copy() { - return new BigMatrixImpl(this.copyOut()); + return new BigMatrixImpl(this.copyOut(), false); } /** @@ -212,47 +253,107 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * * @param m matrix to be added * @return this + m - * @exception IllegalArgumentException if m is not the same size as this + * @throws IllegalArgumentException if m is not the same size as this */ public BigMatrix add(BigMatrix m) throws IllegalArgumentException { - if (this.getColumnDimension() != m.getColumnDimension() || - this.getRowDimension() != m.getRowDimension()) { + try { + return add((BigMatrixImpl) m); + } catch (ClassCastException cce) { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { + throw new IllegalArgumentException("matrix dimension mismatch"); + } + final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; + for (int row = 0; row < rowCount; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] outDataRow = outData[row]; + for (int col = 0; col < columnCount; col++) { + outDataRow[col] = dataRow[col].add(m.getEntry(row, col)); + } + } + return new BigMatrixImpl(outData, false); + } + } + + /** + * Compute the sum of this and m. + * + * @param m matrix to be added + * @return this + m + * @throws IllegalArgumentException if m is not the same size as this + */ + public BigMatrixImpl add(BigMatrixImpl m) throws IllegalArgumentException { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { throw new IllegalArgumentException("matrix dimension mismatch"); } - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; + final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] mRow = m.data[row]; + final BigDecimal[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col].add(m.getEntry(row, col)); - } + outDataRow[col] = dataRow[col].add(mRow[col]); + } } - return new BigMatrixImpl(outData); + return new BigMatrixImpl(outData, false); } - + /** * Compute this minus m. * * @param m matrix to be subtracted * @return this + m - * @exception IllegalArgumentException if m is not the same size as *this + * @throws IllegalArgumentException if m is not the same size as this */ public BigMatrix subtract(BigMatrix m) throws IllegalArgumentException { - if (this.getColumnDimension() != m.getColumnDimension() || - this.getRowDimension() != m.getRowDimension()) { + try { + return subtract((BigMatrixImpl) m); + } catch (ClassCastException cce) { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { + throw new IllegalArgumentException("matrix dimension mismatch"); + } + final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; + for (int row = 0; row < rowCount; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] outDataRow = outData[row]; + for (int col = 0; col < columnCount; col++) { + outDataRow[col] = dataRow[col].subtract(getEntry(row, col)); + } + } + return new BigMatrixImpl(outData, false); + } + } + + /** + * Compute this minus m. + * + * @param m matrix to be subtracted + * @return this + m + * @throws IllegalArgumentException if m is not the same size as this + */ + public BigMatrixImpl subtract(BigMatrixImpl m) throws IllegalArgumentException { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { throw new IllegalArgumentException("matrix dimension mismatch"); } - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; + final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] mRow = m.data[row]; + final BigDecimal[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col].subtract(m.getEntry(row, col)); - } + outDataRow[col] = dataRow[col].subtract(mRow[col]); + } } - return new BigMatrixImpl(outData); + return new BigMatrixImpl(outData, false); } - + /** * Returns the result of adding d to each entry of this. * @@ -260,34 +361,38 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @return d + this */ public BigMatrix scalarAdd(BigDecimal d) { - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col].add(d); + outDataRow[col] = dataRow[col].add(d); } } - return new BigMatrixImpl(outData); + return new BigMatrixImpl(outData, false); } - + /** - * Returns the result multiplying each entry of this by d + * Returns the result of multiplying each entry of this by d * @param d value to multiply all entries by * @return d * this */ public BigMatrix scalarMultiply(BigDecimal d) { - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + final BigDecimal[][] outData = new BigDecimal[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col].multiply(d); + outDataRow[col] = dataRow[col].multiply(d); } } - return new BigMatrixImpl(outData); + return new BigMatrixImpl(outData, false); } - + /** * Returns the result of postmultiplying this by m. * @param m matrix to postmultiply by @@ -296,26 +401,60 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * if columnDimension(this) != rowDimension(m) */ public BigMatrix multiply(BigMatrix m) throws IllegalArgumentException { + try { + return multiply((BigMatrixImpl) m); + } catch (ClassCastException cce) { + if (this.getColumnDimension() != m.getRowDimension()) { + throw new IllegalArgumentException("Matrices are not multiplication compatible."); + } + final int nRows = this.getRowDimension(); + final int nCols = m.getColumnDimension(); + final int nSum = this.getColumnDimension(); + final BigDecimal[][] outData = new BigDecimal[nRows][nCols]; + for (int row = 0; row < nRows; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] outDataRow = outData[row]; + for (int col = 0; col < nCols; col++) { + BigDecimal sum = ZERO; + for (int i = 0; i < nSum; i++) { + sum = sum.add(dataRow[i].multiply(m.getEntry(i, col))); + } + outDataRow[col] = sum; + } + } + return new BigMatrixImpl(outData, false); + } + } + + /** + * Returns the result of postmultiplying this by m. + * @param m matrix to postmultiply by + * @return this*m + * @throws IllegalArgumentException + * if columnDimension(this) != rowDimension(m) + */ + public BigMatrixImpl multiply(BigMatrixImpl m) throws IllegalArgumentException { if (this.getColumnDimension() != m.getRowDimension()) { throw new IllegalArgumentException("Matrices are not multiplication compatible."); } - int nRows = this.getRowDimension(); - int nCols = m.getColumnDimension(); - int nSum = this.getColumnDimension(); - BigDecimal[][] outData = new BigDecimal[nRows][nCols]; - BigDecimal sum = ZERO; + final int nRows = this.getRowDimension(); + final int nCols = m.getColumnDimension(); + final int nSum = this.getColumnDimension(); + final BigDecimal[][] outData = new BigDecimal[nRows][nCols]; for (int row = 0; row < nRows; row++) { + final BigDecimal[] dataRow = data[row]; + final BigDecimal[] outDataRow = outData[row]; for (int col = 0; col < nCols; col++) { - sum = ZERO; + BigDecimal sum = ZERO; for (int i = 0; i < nSum; i++) { - sum = sum.add(data[row][i].multiply(m.getEntry(i, col))); + sum = sum.add(dataRow[i].multiply(m.data[i][col])); } - outData[row][col] = sum; + outDataRow[col] = sum; } - } - return new BigMatrixImpl(outData); + } + return new BigMatrixImpl(outData, false); } - + /** * Returns the result premultiplying this by m. * @param m matrix to premultiply by @@ -326,7 +465,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable { public BigMatrix preMultiply(BigMatrix m) throws IllegalArgumentException { return m.multiply(this); } - + /** * Returns matrix entries as a two-dimensional array. *

@@ -347,11 +486,11 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @return 2-dimensional array of entries */ public double[][] getDataAsDoubleArray() { - int nRows = getRowDimension(); - int nCols = getColumnDimension(); - double d[][] = new double[nRows][nCols]; + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); + final double d[][] = new double[nRows][nCols]; for (int i = 0; i < nRows; i++) { - for (int j=0; j endRow || endRow > data.length || startColumn < 0 || startColumn > endColumn || endColumn > data[0].length ) { throw new MatrixIndexException( "invalid row or column index selection"); } - BigMatrixImpl subMatrix = new BigMatrixImpl(endRow - startRow+1, - endColumn - startColumn+1); - BigDecimal[][] subMatrixData = subMatrix.getDataRef(); + final BigDecimal[][] subMatrixData = + new BigDecimal[endRow - startRow + 1][endColumn - startColumn + 1]; for (int i = startRow; i <= endRow; i++) { - for (int j = startColumn; j <= endColumn; j++) { - subMatrixData[i - startRow][j - startColumn] = data[i][j]; - } + System.arraycopy(data[i], startColumn, + subMatrixData[i - startRow], 0, + endColumn - startColumn + 1); } - return subMatrix; + return new BigMatrixImpl(subMatrixData, false); } /** @@ -468,25 +607,26 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * are not valid */ public BigMatrix getSubMatrix(int[] selectedRows, int[] selectedColumns) - throws MatrixIndexException { + throws MatrixIndexException { if (selectedRows.length * selectedColumns.length == 0) { throw new MatrixIndexException( "selected row and column index arrays must be non-empty"); } - BigMatrixImpl subMatrix = new BigMatrixImpl(selectedRows.length, - selectedColumns.length); - BigDecimal[][] subMatrixData = subMatrix.getDataRef(); + final BigDecimal[][] subMatrixData = + new BigDecimal[selectedRows.length][selectedColumns.length]; try { for (int i = 0; i < selectedRows.length; i++) { + final BigDecimal[] subI = subMatrixData[i]; + final BigDecimal[] dataSelectedI = data[selectedRows[i]]; for (int j = 0; j < selectedColumns.length; j++) { - subMatrixData[i][j] = data[selectedRows[i]][selectedColumns[j]]; + subI[j] = dataSelectedI[selectedColumns[j]]; } } } catch (ArrayIndexOutOfBoundsException e) { throw new MatrixIndexException("matrix dimension mismatch"); } - return subMatrix; + return new BigMatrixImpl(subMatrixData, false); } /** @@ -522,12 +662,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable { throw new MatrixIndexException ("invalid row or column index selection"); } - int nRows = subMatrix.length; + final int nRows = subMatrix.length; if (nRows == 0) { throw new IllegalArgumentException( "Matrix must have at least one row."); } - int nCols = subMatrix[0].length; + final int nCols = subMatrix[0].length; if (nCols == 0) { throw new IllegalArgumentException( "Matrix must have at least one column."); @@ -566,10 +706,10 @@ public class BigMatrixImpl implements BigMatrix, Serializable { if ( !isValidCoordinate( row, 0)) { throw new MatrixIndexException("illegal row argument"); } - int ncols = this.getColumnDimension(); - BigDecimal[][] out = new BigDecimal[1][ncols]; + final int ncols = this.getColumnDimension(); + final BigDecimal[][] out = new BigDecimal[1][ncols]; System.arraycopy(data[row], 0, out[0], 0, ncols); - return new BigMatrixImpl(out); + return new BigMatrixImpl(out, false); } /** @@ -584,12 +724,12 @@ public class BigMatrixImpl implements BigMatrix, Serializable { if ( !isValidCoordinate( 0, column)) { throw new MatrixIndexException("illegal column argument"); } - int nRows = this.getRowDimension(); - BigDecimal[][] out = new BigDecimal[nRows][1]; + final int nRows = this.getRowDimension(); + final BigDecimal[][] out = new BigDecimal[nRows][1]; for (int row = 0; row < nRows; row++) { out[row][0] = data[row][column]; } - return new BigMatrixImpl(out); + return new BigMatrixImpl(out, false); } /** @@ -606,8 +746,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable { if ( !isValidCoordinate( row, 0 ) ) { throw new MatrixIndexException("illegal row argument"); } - int ncols = this.getColumnDimension(); - BigDecimal[] out = new BigDecimal[ncols]; + final int ncols = this.getColumnDimension(); + final BigDecimal[] out = new BigDecimal[ncols]; System.arraycopy(data[row], 0, out, 0, ncols); return out; } @@ -627,8 +767,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable { if ( !isValidCoordinate( row, 0 ) ) { throw new MatrixIndexException("illegal row argument"); } - int ncols = this.getColumnDimension(); - double[] out = new double[ncols]; + final int ncols = this.getColumnDimension(); + final double[] out = new double[ncols]; for (int i=0;i= 0; col--) { + final BigDecimal[] bpCol = bp[col]; + final BigDecimal luDiag = lu[col][col]; for (int j = 0; j < nColB; j++) { - bp[col][j] = bp[col][j].divide(lu[col][col], scale, roundingMode); + bpCol[j] = bpCol[j].divide(luDiag, scale, roundingMode); } for (int i = 0; i < col; i++) { + final BigDecimal[] bpI = bp[i]; + final BigDecimal[] luI = lu[i]; for (int j = 0; j < nColB; j++) { - bp[i][j] = bp[i][j].subtract(bp[col][j].multiply(lu[i][col])); + bpI[j] = bpI[j].subtract(bp[col][j].multiply(luI[col])); } } } - - BigMatrixImpl outMat = new BigMatrixImpl(bp); - return outMat; + + return new BigMatrixImpl(bp, false); + } /** @@ -1021,8 +1168,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable { */ public void luDecompose() throws InvalidMatrixException { - int nRows = this.getRowDimension(); - int nCols = this.getColumnDimension(); + final int nRows = this.getRowDimension(); + final int nCols = this.getColumnDimension(); if (nRows != nCols) { throw new InvalidMatrixException("LU decomposition requires that the matrix be square."); } @@ -1042,22 +1189,24 @@ public class BigMatrixImpl implements BigMatrix, Serializable { // upper for (int row = 0; row < col; row++) { - sum = lu[row][col]; + final BigDecimal[] luRow = lu[row]; + sum = luRow[col]; for (int i = 0; i < row; i++) { - sum = sum.subtract(lu[row][i].multiply(lu[i][col])); + sum = sum.subtract(luRow[i].multiply(lu[i][col])); } - lu[row][col] = sum; + luRow[col] = sum; } // lower int max = col; // permutation row BigDecimal largest = ZERO; for (int row = col; row < nRows; row++) { - sum = lu[row][col]; + final BigDecimal[] luRow = lu[row]; + sum = luRow[col]; for (int i = 0; i < col; i++) { - sum = sum.subtract(lu[row][i].multiply(lu[i][col])); + sum = sum.subtract(luRow[i].multiply(lu[i][col])); } - lu[row][col] = sum; + luRow[col] = sum; // maintain best permutation choice if (sum.abs().compareTo(largest) == 1) { @@ -1086,9 +1235,11 @@ public class BigMatrixImpl implements BigMatrix, Serializable { parity = -parity; } - //Divide the lower elements by the "winning" diagonal elt. + // Divide the lower elements by the "winning" diagonal elt. + final BigDecimal luDiag = lu[col][col]; for (int row = col + 1; row < nRows; row++) { - lu[row][col] = lu[row][col].divide(lu[col][col], scale, roundingMode); + final BigDecimal[] luRow = lu[row]; + luRow[col] = luRow[col].divide(luDiag, scale, roundingMode); } } @@ -1104,12 +1255,14 @@ public class BigMatrixImpl implements BigMatrix, Serializable { res.append("BigMatrixImpl{"); if (data != null) { for (int i = 0; i < data.length; i++) { - if (i > 0) + if (i > 0) { res.append(","); + } res.append("{"); for (int j = 0; j < data[0].length; j++) { - if (j > 0) + if (j > 0) { res.append(","); + } res.append(data[i][j]); } res.append("}"); @@ -1135,15 +1288,16 @@ public class BigMatrixImpl implements BigMatrix, Serializable { if (object instanceof BigMatrixImpl == false) { return false; } - BigMatrix m = (BigMatrix) object; - int nRows = getRowDimension(); - int nCols = getColumnDimension(); + final BigMatrix m = (BigMatrix) object; + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); if (m.getColumnDimension() != nCols || m.getRowDimension() != nRows) { return false; } for (int row = 0; row < nRows; row++) { + final BigDecimal[] dataRow = data[row]; for (int col = 0; col < nCols; col++) { - if (!data[row][col].equals(m.getEntry(row, col))) { + if (!dataRow[col].equals(m.getEntry(row, col))) { return false; } } @@ -1158,14 +1312,15 @@ public class BigMatrixImpl implements BigMatrix, Serializable { */ public int hashCode() { int ret = 7; - int nRows = getRowDimension(); - int nCols = getColumnDimension(); + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); ret = ret * 31 + nRows; ret = ret * 31 + nCols; for (int row = 0; row < nRows; row++) { + final BigDecimal[] dataRow = data[row]; for (int col = 0; col < nCols; col++) { ret = ret * 31 + (11 * (row+1) + 17 * (col+1)) * - data[row][col].hashCode(); + dataRow[col].hashCode(); } } return ret; @@ -1220,7 +1375,7 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @return the permutation */ protected int[] getPermutation() { - int[] out = new int[permutation.length]; + final int[] out = new int[permutation.length]; System.arraycopy(permutation, 0, out, 0, permutation.length); return out; } @@ -1233,8 +1388,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @return a copy of the underlying data array. */ private BigDecimal[][] copyOut() { - int nRows = this.getRowDimension(); - BigDecimal[][] out = new BigDecimal[nRows][this.getColumnDimension()]; + final int nRows = this.getRowDimension(); + final BigDecimal[][] out = new BigDecimal[nRows][this.getColumnDimension()]; // can't copy 2-d array in one shot, otherwise get row references for (int i = 0; i < nRows; i++) { System.arraycopy(data[i], 0, out[i], 0, data[i].length); @@ -1262,12 +1417,14 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @param in data to copy in */ private void copyIn(double[][] in) { - int nRows = in.length; - int nCols = in[0].length; + final int nRows = in.length; + final int nCols = in[0].length; data = new BigDecimal[nRows][nCols]; for (int i = 0; i < nRows; i++) { - for (int j=0; j < nCols; j++) { - data[i][j] = new BigDecimal(in[i][j]); + final BigDecimal[] dataI = data[i]; + final double[] inI = in[i]; + for (int j = 0; j < nCols; j++) { + dataI[j] = new BigDecimal(inI[j]); } } lu = null; @@ -1280,12 +1437,14 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @param in data to copy in */ private void copyIn(String[][] in) { - int nRows = in.length; - int nCols = in[0].length; + final int nRows = in.length; + final int nCols = in[0].length; data = new BigDecimal[nRows][nCols]; for (int i = 0; i < nRows; i++) { - for (int j=0; j < nCols; j++) { - data[i][j] = new BigDecimal(in[i][j]); + final BigDecimal[] dataI = data[i]; + final String[] inI = in[i]; + for (int j = 0; j < nCols; j++) { + dataI[j] = new BigDecimal(inI[j]); } } lu = null; @@ -1299,9 +1458,8 @@ public class BigMatrixImpl implements BigMatrix, Serializable { * @return true if the coordinate is with the current dimensions */ private boolean isValidCoordinate(int row, int col) { - int nRows = this.getRowDimension(); - int nCols = this.getColumnDimension(); - + final int nRows = this.getRowDimension(); + final int nCols = this.getColumnDimension(); return !(row < 0 || row >= nRows || col < 0 || col >= nCols); } diff --git a/src/java/org/apache/commons/math/linear/MatrixUtils.java b/src/java/org/apache/commons/math/linear/MatrixUtils.java index 6f08dd0a1..2f9bf3e3d 100644 --- a/src/java/org/apache/commons/math/linear/MatrixUtils.java +++ b/src/java/org/apache/commons/math/linear/MatrixUtils.java @@ -18,6 +18,7 @@ package org.apache.commons.math.linear; import java.math.BigDecimal; +import java.util.Arrays; /** * A collection of static methods that operate on or return matrices. @@ -41,12 +42,33 @@ public class MatrixUtils { * @return RealMatrix containing the values of the array * @throws IllegalArgumentException if data is not rectangular * (not all rows have the same length) or empty - * @throws NullPointerException if data is null + * @throws NullPointerException if data is null + * @see #createRealMatrix(double[][], boolean) */ public static RealMatrix createRealMatrix(double[][] data) { return new RealMatrixImpl(data); } - + + /** + * Returns a {@link RealMatrix} whose entries are the the values in the + * the input array. + *

If an array is built specially in order to be embedded in a + * RealMatrix and not used directly, the copyArray may be + * set to false + * @param data data for new matrix + * @param copyArray if true, the input array will be copied, otherwise + * it will be referenced + * @return RealMatrix containing the values of the array + * @throws IllegalArgumentException if data is not rectangular + * (not all rows have the same length) or empty + * @throws NullPointerException if data is null + * @see #createRealMatrix(double[][]) + */ + public static RealMatrix createRealMatrix(double[][] data, boolean copyArray) { + return new RealMatrixImpl(data, copyArray); + } + /** * Returns dimension x dimension identity matrix. * @@ -56,14 +78,11 @@ public class MatrixUtils { * @since 1.1 */ public static RealMatrix createRealIdentityMatrix(int dimension) { - RealMatrixImpl out = new RealMatrixImpl(dimension, dimension); - double[][] d = out.getDataRef(); + double[][] d = new double[dimension][dimension]; for (int row = 0; row < dimension; row++) { - for (int col = 0; col < dimension; col++) { - d[row][col] = row == col ? 1d : 0d; - } + d[row][row] = 1d; } - return out; + return new RealMatrixImpl(d, false); } /** @@ -93,7 +112,27 @@ public class MatrixUtils { public static BigMatrix createBigMatrix(BigDecimal[][] data) { return new BigMatrixImpl(data); } - + + /** + * Returns a {@link BigMatrix} whose entries are the the values in the + * the input array. + *

If an array is built specially in order to be embedded in a + * BigMatrix and not used directly, the copyArray may be + * set to false + * @param data data for new matrix + * @param copyArray if true, the input array will be copied, otherwise + * it will be referenced + * @return BigMatrix containing the values of the array + * @throws IllegalArgumentException if data is not rectangular + * (not all rows have the same length) or empty + * @throws NullPointerException if data is null + * @see #createRealMatrix(double[][]) + */ + public static BigMatrix createBigMatrix(BigDecimal[][] data, boolean copyArray) { + return new BigMatrixImpl(data, copyArray); + } + /** * Returns a {@link BigMatrix} whose entries are the the values in the * the input array. The input array is copied, not referenced. @@ -118,10 +157,10 @@ public class MatrixUtils { * @throws NullPointerException if rowDatais null */ public static RealMatrix createRowRealMatrix(double[] rowData) { - int nCols = rowData.length; - double[][] data = new double[1][nCols]; + final int nCols = rowData.length; + final double[][] data = new double[1][nCols]; System.arraycopy(rowData, 0, data[0], 0, nCols); - return new RealMatrixImpl(data); + return new RealMatrixImpl(data, false); } /** @@ -134,10 +173,12 @@ public class MatrixUtils { * @throws NullPointerException if rowDatais null */ public static BigMatrix createRowBigMatrix(double[] rowData) { - int nCols = rowData.length; - double[][] data = new double[1][nCols]; - System.arraycopy(rowData, 0, data[0], 0, nCols); - return new BigMatrixImpl(data); + final int nCols = rowData.length; + final BigDecimal[][] data = new BigDecimal[1][nCols]; + for (int i = 0; i < nCols; ++i) { + data[0][i] = new BigDecimal(rowData[i]); + } + return new BigMatrixImpl(data, false); } /** @@ -150,10 +191,10 @@ public class MatrixUtils { * @throws NullPointerException if rowDatais null */ public static BigMatrix createRowBigMatrix(BigDecimal[] rowData) { - int nCols = rowData.length; - BigDecimal[][] data = new BigDecimal[1][nCols]; + final int nCols = rowData.length; + final BigDecimal[][] data = new BigDecimal[1][nCols]; System.arraycopy(rowData, 0, data[0], 0, nCols); - return new BigMatrixImpl(data); + return new BigMatrixImpl(data, false); } /** @@ -166,10 +207,12 @@ public class MatrixUtils { * @throws NullPointerException if rowDatais null */ public static BigMatrix createRowBigMatrix(String[] rowData) { - int nCols = rowData.length; - String[][] data = new String[1][nCols]; - System.arraycopy(rowData, 0, data[0], 0, nCols); - return new BigMatrixImpl(data); + final int nCols = rowData.length; + final BigDecimal[][] data = new BigDecimal[1][nCols]; + for (int i = 0; i < nCols; ++i) { + data[0][i] = new BigDecimal(rowData[i]); + } + return new BigMatrixImpl(data, false); } /** @@ -182,12 +225,12 @@ public class MatrixUtils { * @throws NullPointerException if columnDatais null */ public static RealMatrix createColumnRealMatrix(double[] columnData) { - int nRows = columnData.length; - double[][] data = new double[nRows][1]; + final int nRows = columnData.length; + final double[][] data = new double[nRows][1]; for (int row = 0; row < nRows; row++) { data[row][0] = columnData[row]; } - return new RealMatrixImpl(data); + return new RealMatrixImpl(data, false); } /** @@ -200,12 +243,12 @@ public class MatrixUtils { * @throws NullPointerException if columnDatais null */ public static BigMatrix createColumnBigMatrix(double[] columnData) { - int nRows = columnData.length; - double[][] data = new double[nRows][1]; + final int nRows = columnData.length; + final BigDecimal[][] data = new BigDecimal[nRows][1]; for (int row = 0; row < nRows; row++) { - data[row][0] = columnData[row]; + data[row][0] = new BigDecimal(columnData[row]); } - return new BigMatrixImpl(data); + return new BigMatrixImpl(data, false); } /** @@ -218,12 +261,12 @@ public class MatrixUtils { * @throws NullPointerException if columnDatais null */ public static BigMatrix createColumnBigMatrix(BigDecimal[] columnData) { - int nRows = columnData.length; - BigDecimal[][] data = new BigDecimal[nRows][1]; + final int nRows = columnData.length; + final BigDecimal[][] data = new BigDecimal[nRows][1]; for (int row = 0; row < nRows; row++) { data[row][0] = columnData[row]; } - return new BigMatrixImpl(data); + return new BigMatrixImpl(data, false); } /** @@ -237,11 +280,11 @@ public class MatrixUtils { */ public static BigMatrix createColumnBigMatrix(String[] columnData) { int nRows = columnData.length; - String[][] data = new String[nRows][1]; + final BigDecimal[][] data = new BigDecimal[nRows][1]; for (int row = 0; row < nRows; row++) { - data[row][0] = columnData[row]; + data[row][0] = new BigDecimal(columnData[row]); } - return new BigMatrixImpl(data); + return new BigMatrixImpl(data, false); } /** @@ -253,14 +296,13 @@ public class MatrixUtils { * @since 1.1 */ public static BigMatrix createBigIdentityMatrix(int dimension) { - BigMatrixImpl out = new BigMatrixImpl(dimension, dimension); - BigDecimal[][] d = out.getDataRef(); + final BigDecimal[][] d = new BigDecimal[dimension][dimension]; for (int row = 0; row < dimension; row++) { - for (int col = 0; col < dimension; col++) { - d[row][col] = row == col ? BigMatrixImpl.ONE : BigMatrixImpl.ZERO; - } + final BigDecimal[] dRow = d[row]; + Arrays.fill(dRow, BigMatrixImpl.ZERO); + dRow[row] = BigMatrixImpl.ONE; } - return out; + return new BigMatrixImpl(d, false); } } diff --git a/src/java/org/apache/commons/math/linear/RealMatrixImpl.java b/src/java/org/apache/commons/math/linear/RealMatrixImpl.java index 8c5b96c12..97599f509 100644 --- a/src/java/org/apache/commons/math/linear/RealMatrixImpl.java +++ b/src/java/org/apache/commons/math/linear/RealMatrixImpl.java @@ -52,7 +52,7 @@ import org.apache.commons.math.util.MathUtils; public class RealMatrixImpl implements RealMatrix, Serializable { /** Serializable version identifier */ - private static final long serialVersionUID = 4237564493130426188L; + private static final long serialVersionUID = -4828886979278117018L; /** Entries of the matrix */ private double data[][] = null; @@ -97,16 +97,58 @@ public class RealMatrixImpl implements RealMatrix, Serializable { /** * Create a new RealMatrix using the input array as the underlying * data array. - *

- * The input array is copied, not referenced.

+ *

The input array is copied, not referenced. This constructor has + * the same effect as calling {@link #RealMatrixImpl(double[][], boolean)} + * with the second argument set to true.

* * @param d data for new matrix * @throws IllegalArgumentException if d is not rectangular * (not all rows have the same length) or empty * @throws NullPointerException if d is null + * @see #RealMatrixImpl(double[][], boolean) */ public RealMatrixImpl(double[][] d) { - this.copyIn(d); + copyIn(d); + lu = null; + } + + /** + * Create a new RealMatrix using the input array as the underlying + * data array. + *

If an array is built specially in order to be embedded in a + * RealMatrix and not used directly, the copyArray may be + * set to false + * @param d data for new matrix + * @param copyArray if true, the input array will be copied, otherwise + * it will be referenced + * @throws IllegalArgumentException if d is not rectangular + * (not all rows have the same length) or empty + * @throws NullPointerException if d is null + * @see #RealMatrixImpl(double[][]) + */ + public RealMatrixImpl(double[][] d, boolean copyArray) { + if (copyArray) { + copyIn(d); + } else { + if (d == null) { + throw new NullPointerException(); + } + final int nRows = d.length; + if (nRows == 0) { + throw new IllegalArgumentException("Matrix must have at least one row."); + } + final int nCols = d[0].length; + if (nCols == 0) { + throw new IllegalArgumentException("Matrix must have at least one column."); + } + for (int r = 1; r < nRows; r++) { + if (d[r].length != nCols) { + throw new IllegalArgumentException("All input rows must have the same length."); + } + } + data = d; + } lu = null; } @@ -114,13 +156,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * Create a new (column) RealMatrix using v as the * data for the unique column of the v.length x 1 matrix * created. - *

- * The input array is copied, not referenced.

+ *

The input array is copied, not referenced.

* * @param v column vector holding data for new matrix */ public RealMatrixImpl(double[] v) { - int nRows = v.length; + final int nRows = v.length; data = new double[nRows][1]; for (int row = 0; row < nRows; row++) { data[row][0] = v[row]; @@ -133,7 +174,7 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return the cloned matrix */ public RealMatrix copy() { - return new RealMatrixImpl(this.copyOut()); + return new RealMatrixImpl(copyOut(), false); } /** @@ -144,19 +185,49 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @throws IllegalArgumentException if m is not the same size as this */ public RealMatrix add(RealMatrix m) throws IllegalArgumentException { - if (this.getColumnDimension() != m.getColumnDimension() || - this.getRowDimension() != m.getRowDimension()) { + try { + return add((RealMatrixImpl) m); + } catch (ClassCastException cce) { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { + throw new IllegalArgumentException("matrix dimension mismatch"); + } + final double[][] outData = new double[rowCount][columnCount]; + for (int row = 0; row < rowCount; row++) { + final double[] dataRow = data[row]; + final double[] outDataRow = outData[row]; + for (int col = 0; col < columnCount; col++) { + outDataRow[col] = dataRow[col] + m.getEntry(row, col); + } + } + return new RealMatrixImpl(outData, false); + } + } + + /** + * Compute the sum of this and m. + * + * @param m matrix to be added + * @return this + m + * @throws IllegalArgumentException if m is not the same size as this + */ + public RealMatrixImpl add(RealMatrixImpl m) throws IllegalArgumentException { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { throw new IllegalArgumentException("matrix dimension mismatch"); } - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - double[][] outData = new double[rowCount][columnCount]; + final double[][] outData = new double[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final double[] dataRow = data[row]; + final double[] mRow = m.data[row]; + final double[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col] + m.getEntry(row, col); + outDataRow[col] = dataRow[col] + mRow[col]; } } - return new RealMatrixImpl(outData); + return new RealMatrixImpl(outData, false); } /** @@ -167,19 +238,49 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @throws IllegalArgumentException if m is not the same size as this */ public RealMatrix subtract(RealMatrix m) throws IllegalArgumentException { - if (this.getColumnDimension() != m.getColumnDimension() || - this.getRowDimension() != m.getRowDimension()) { + try { + return subtract((RealMatrixImpl) m); + } catch (ClassCastException cce) { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { + throw new IllegalArgumentException("matrix dimension mismatch"); + } + final double[][] outData = new double[rowCount][columnCount]; + for (int row = 0; row < rowCount; row++) { + final double[] dataRow = data[row]; + final double[] outDataRow = outData[row]; + for (int col = 0; col < columnCount; col++) { + outDataRow[col] = dataRow[col] - m.getEntry(row, col); + } + } + return new RealMatrixImpl(outData, false); + } + } + + /** + * Compute this minus m. + * + * @param m matrix to be subtracted + * @return this + m + * @throws IllegalArgumentException if m is not the same size as this + */ + public RealMatrixImpl subtract(RealMatrixImpl m) throws IllegalArgumentException { + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) { throw new IllegalArgumentException("matrix dimension mismatch"); } - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - double[][] outData = new double[rowCount][columnCount]; + final double[][] outData = new double[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final double[] dataRow = data[row]; + final double[] mRow = m.data[row]; + final double[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col] - m.getEntry(row, col); - } + outDataRow[col] = dataRow[col] - mRow[col]; + } } - return new RealMatrixImpl(outData); + return new RealMatrixImpl(outData, false); } /** @@ -189,32 +290,36 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return d + this */ public RealMatrix scalarAdd(double d) { - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - double[][] outData = new double[rowCount][columnCount]; + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + final double[][] outData = new double[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final double[] dataRow = data[row]; + final double[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col] + d; + outDataRow[col] = dataRow[col] + d; } } - return new RealMatrixImpl(outData); + return new RealMatrixImpl(outData, false); } /** - * Returns the result multiplying each entry of this by d + * Returns the result of multiplying each entry of this by d * @param d value to multiply all entries by * @return d * this */ public RealMatrix scalarMultiply(double d) { - int rowCount = this.getRowDimension(); - int columnCount = this.getColumnDimension(); - double[][] outData = new double[rowCount][columnCount]; + final int rowCount = getRowDimension(); + final int columnCount = getColumnDimension(); + final double[][] outData = new double[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { + final double[] dataRow = data[row]; + final double[] outDataRow = outData[row]; for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col] * d; + outDataRow[col] = dataRow[col] * d; } } - return new RealMatrixImpl(outData); + return new RealMatrixImpl(outData, false); } /** @@ -225,28 +330,62 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * if columnDimension(this) != rowDimension(m) */ public RealMatrix multiply(RealMatrix m) throws IllegalArgumentException { - if (this.getColumnDimension() != m.getRowDimension()) { - throw new IllegalArgumentException("Matrices are not multiplication compatible."); - } - int nRows = this.getRowDimension(); - int nCols = m.getColumnDimension(); - int nSum = this.getColumnDimension(); - double[][] outData = new double[nRows][nCols]; - double sum = 0; - for (int row = 0; row < nRows; row++) { - for (int col = 0; col < nCols; col++) { - sum = 0; - for (int i = 0; i < nSum; i++) { - sum += data[row][i] * m.getEntry(i, col); - } - outData[row][col] = sum; + try { + return multiply((RealMatrixImpl) m); + } catch (ClassCastException cce) { + if (this.getColumnDimension() != m.getRowDimension()) { + throw new IllegalArgumentException("Matrices are not multiplication compatible."); } + final int nRows = this.getRowDimension(); + final int nCols = m.getColumnDimension(); + final int nSum = this.getColumnDimension(); + final double[][] outData = new double[nRows][nCols]; + for (int row = 0; row < nRows; row++) { + final double[] dataRow = data[row]; + final double[] outDataRow = outData[row]; + for (int col = 0; col < nCols; col++) { + double sum = 0; + for (int i = 0; i < nSum; i++) { + sum += dataRow[i] * m.getEntry(i, col); + } + outDataRow[col] = sum; + } + } + return new RealMatrixImpl(outData, false); } - return new RealMatrixImpl(outData); } /** - * Returns the result premultiplying this by m. + * Returns the result of postmultiplying this by m. + * @param m matrix to postmultiply by + * @return this*m + * @throws IllegalArgumentException + * if columnDimension(this) != rowDimension(m) + */ + public RealMatrixImpl multiply(RealMatrixImpl m) throws IllegalArgumentException { + if (this.getColumnDimension() != m.getRowDimension()) { + throw new IllegalArgumentException("Matrices are not multiplication compatible."); + } + final int nRows = this.getRowDimension(); + final int nCols = m.getColumnDimension(); + final int nSum = this.getColumnDimension(); + final double[][] outData = new double[nRows][nCols]; + for (int row = 0; row < nRows; row++) { + final double[] dataRow = data[row]; + final double[] outDataRow = outData[row]; + for (int col = 0; col < nCols; col++) { + double sum = 0; + for (int i = 0; i < nSum; i++) { + sum += dataRow[i] * m.data[i][col]; + } + outDataRow[col] = sum; + } + } + return new RealMatrixImpl(outData, false); + } + + /** + * Returns the result of premultiplying this by m. * @param m matrix to premultiply by * @return m * this * @throws IllegalArgumentException @@ -306,23 +445,23 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * specified rows and columns * @exception MatrixIndexException if row or column selections are not valid */ - public RealMatrix getSubMatrix(int startRow, int endRow, int startColumn, - int endColumn) throws MatrixIndexException { + public RealMatrix getSubMatrix(int startRow, int endRow, + int startColumn, int endColumn) + throws MatrixIndexException { if (startRow < 0 || startRow > endRow || endRow > data.length || startColumn < 0 || startColumn > endColumn || - endColumn > data[0].length ) { + endColumn > data[0].length) { throw new MatrixIndexException( "invalid row or column index selection"); } - RealMatrixImpl subMatrix = new RealMatrixImpl(endRow - startRow+1, - endColumn - startColumn+1); - double[][] subMatrixData = subMatrix.getDataRef(); + final double[][] subMatrixData = + new double[endRow - startRow + 1][endColumn - startColumn + 1]; for (int i = startRow; i <= endRow; i++) { - for (int j = startColumn; j <= endColumn; j++) { - subMatrixData[i - startRow][j - startColumn] = data[i][j]; - } - } - return subMatrix; + System.arraycopy(data[i], startColumn, + subMatrixData[i - startRow], 0, + endColumn - startColumn + 1); + } + return new RealMatrixImpl(subMatrixData, false); } /** @@ -337,25 +476,25 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * are not valid */ public RealMatrix getSubMatrix(int[] selectedRows, int[] selectedColumns) - throws MatrixIndexException { + throws MatrixIndexException { if (selectedRows.length * selectedColumns.length == 0) { throw new MatrixIndexException( "selected row and column index arrays must be non-empty"); } - RealMatrixImpl subMatrix = new RealMatrixImpl(selectedRows.length, - selectedColumns.length); - double[][] subMatrixData = subMatrix.getDataRef(); + final double[][] subMatrixData = + new double[selectedRows.length][selectedColumns.length]; try { for (int i = 0; i < selectedRows.length; i++) { + final double[] subI = subMatrixData[i]; + final double[] dataSelectedI = data[selectedRows[i]]; for (int j = 0; j < selectedColumns.length; j++) { - subMatrixData[i][j] = data[selectedRows[i]][selectedColumns[j]]; + subI[j] = dataSelectedI[selectedColumns[j]]; } } - } - catch (ArrayIndexOutOfBoundsException e) { + } catch (ArrayIndexOutOfBoundsException e) { throw new MatrixIndexException("matrix dimension mismatch"); } - return subMatrix; + return new RealMatrixImpl(subMatrixData, false); } /** @@ -391,12 +530,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable { throw new MatrixIndexException ("invalid row or column index selection"); } - int nRows = subMatrix.length; + final int nRows = subMatrix.length; if (nRows == 0) { throw new IllegalArgumentException( "Matrix must have at least one row."); } - int nCols = subMatrix[0].length; + final int nCols = subMatrix[0].length; if (nCols == 0) { throw new IllegalArgumentException( "Matrix must have at least one column."); @@ -435,10 +574,10 @@ public class RealMatrixImpl implements RealMatrix, Serializable { if ( !isValidCoordinate( row, 0)) { throw new MatrixIndexException("illegal row argument"); } - int ncols = this.getColumnDimension(); - double[][] out = new double[1][ncols]; + final int ncols = this.getColumnDimension(); + final double[][] out = new double[1][ncols]; System.arraycopy(data[row], 0, out[0], 0, ncols); - return new RealMatrixImpl(out); + return new RealMatrixImpl(out, false); } /** @@ -453,12 +592,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable { if ( !isValidCoordinate( 0, column)) { throw new MatrixIndexException("illegal column argument"); } - int nRows = this.getRowDimension(); - double[][] out = new double[nRows][1]; + final int nRows = this.getRowDimension(); + final double[][] out = new double[nRows][1]; for (int row = 0; row < nRows; row++) { out[row][0] = data[row][column]; } - return new RealMatrixImpl(out); + return new RealMatrixImpl(out, false); } /** @@ -475,8 +614,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable { if ( !isValidCoordinate( row, 0 ) ) { throw new MatrixIndexException("illegal row argument"); } - int ncols = this.getColumnDimension(); - double[] out = new double[ncols]; + final int ncols = this.getColumnDimension(); + final double[] out = new double[ncols]; System.arraycopy(data[row], 0, out, 0, ncols); return out; } @@ -495,8 +634,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable { if ( !isValidCoordinate(0, col) ) { throw new MatrixIndexException("illegal column argument"); } - int nRows = this.getRowDimension(); - double[] out = new double[nRows]; + final int nRows = this.getRowDimension(); + final double[] out = new double[nRows]; for (int row = 0; row < nRows; row++) { out[row] = data[row][col]; } @@ -520,10 +659,11 @@ public class RealMatrixImpl implements RealMatrix, Serializable { */ public double getEntry(int row, int column) throws MatrixIndexException { - if (!isValidCoordinate(row,column)) { + try { + return data[row][column]; + } catch (ArrayIndexOutOfBoundsException e) { throw new MatrixIndexException("matrix entry does not exist"); } - return data[row][column]; } /** @@ -532,16 +672,16 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return transpose matrix */ public RealMatrix transpose() { - int nRows = this.getRowDimension(); - int nCols = this.getColumnDimension(); - RealMatrixImpl out = new RealMatrixImpl(nCols, nRows); - double[][] outData = out.getDataRef(); + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); + final double[][] outData = new double[nCols][nRows]; for (int row = 0; row < nRows; row++) { + final double[] dataRow = data[row]; for (int col = 0; col < nCols; col++) { - outData[col][row] = data[row][col]; + outData[col][row] = dataRow[col]; } } - return out; + return new RealMatrixImpl(outData, false); } /** @@ -551,8 +691,7 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @throws InvalidMatrixException if this is not invertible */ public RealMatrix inverse() throws InvalidMatrixException { - return solve(MatrixUtils.createRealIdentityMatrix - (this.getRowDimension())); + return solve(MatrixUtils.createRealIdentityMatrix(getRowDimension())); } /** @@ -632,16 +771,17 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return resulting vector */ public double[] operate(double[] v) throws IllegalArgumentException { - if (v.length != this.getColumnDimension()) { + final int nRows = this.getRowDimension(); + final int nCols = this.getColumnDimension(); + if (v.length != nCols) { throw new IllegalArgumentException("vector has wrong length"); } - int nRows = this.getRowDimension(); - int nCols = this.getColumnDimension(); - double[] out = new double[v.length]; + final double[] out = new double[v.length]; for (int row = 0; row < nRows; row++) { + final double[] dataRow = data[row]; double sum = 0; for (int i = 0; i < nCols; i++) { - sum += data[row][i] * v[i]; + sum += dataRow[i] * v[i]; } out[row] = sum; } @@ -654,12 +794,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return resulting matrix */ public double[] preMultiply(double[] v) throws IllegalArgumentException { - int nRows = this.getRowDimension(); + final int nRows = this.getRowDimension(); if (v.length != nRows) { throw new IllegalArgumentException("vector has wrong length"); } - int nCols = this.getColumnDimension(); - double[] out = new double[nCols]; + final int nCols = this.getColumnDimension(); + final double[] out = new double[nCols]; for (int col = 0; col < nCols; col++) { double sum = 0; for (int i = 0; i < nRows; i++) { @@ -682,13 +822,13 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @throws InvalidMatrixException if this matrix is not square or is singular */ public double[] solve(double[] b) throws IllegalArgumentException, InvalidMatrixException { - int nRows = this.getRowDimension(); + final int nRows = this.getRowDimension(); if (b.length != nRows) { throw new IllegalArgumentException("constant vector has wrong length"); } - RealMatrix bMatrix = new RealMatrixImpl(b); - double[][] solution = ((RealMatrixImpl) (solve(bMatrix))).getDataRef(); - double[] out = new double[nRows]; + final RealMatrix bMatrix = new RealMatrixImpl(b); + final double[][] solution = ((RealMatrixImpl) (solve(bMatrix))).getDataRef(); + final double[] out = new double[nRows]; for (int row = 0; row < nRows; row++) { out[row] = solution[row][0]; } @@ -717,41 +857,48 @@ public class RealMatrixImpl implements RealMatrix, Serializable { throw new InvalidMatrixException("Matrix is singular."); } - int nCol = this.getColumnDimension(); - int nColB = b.getColumnDimension(); - int nRowB = b.getRowDimension(); + final int nCol = this.getColumnDimension(); + final int nColB = b.getColumnDimension(); + final int nRowB = b.getRowDimension(); // Apply permutations to b - double[][] bp = new double[nRowB][nColB]; + final double[][] bp = new double[nRowB][nColB]; for (int row = 0; row < nRowB; row++) { + final double[] bpRow = bp[row]; for (int col = 0; col < nColB; col++) { - bp[row][col] = b.getEntry(permutation[row], col); + bpRow[col] = b.getEntry(permutation[row], col); } } // Solve LY = b for (int col = 0; col < nCol; col++) { for (int i = col + 1; i < nCol; i++) { + final double[] bpI = bp[i]; + final double[] luI = lu[i]; for (int j = 0; j < nColB; j++) { - bp[i][j] -= bp[col][j] * lu[i][col]; + bpI[j] -= bp[col][j] * luI[col]; } } } // Solve UX = Y for (int col = nCol - 1; col >= 0; col--) { + final double[] bpCol = bp[col]; + final double luDiag = lu[col][col]; for (int j = 0; j < nColB; j++) { - bp[col][j] /= lu[col][col]; + bpCol[j] /= luDiag; } for (int i = 0; i < col; i++) { + final double[] bpI = bp[i]; + final double[] luI = lu[i]; for (int j = 0; j < nColB; j++) { - bp[i][j] -= bp[col][j] * lu[i][col]; + bpI[j] -= bp[col][j] * luI[col]; } } } - RealMatrixImpl outMat = new RealMatrixImpl(bp); - return outMat; + return new RealMatrixImpl(bp, false); + } /** @@ -774,12 +921,12 @@ public class RealMatrixImpl implements RealMatrix, Serializable { */ public void luDecompose() throws InvalidMatrixException { - int nRows = this.getRowDimension(); - int nCols = this.getColumnDimension(); + final int nRows = this.getRowDimension(); + final int nCols = this.getColumnDimension(); if (nRows != nCols) { throw new InvalidMatrixException("LU decomposition requires that the matrix be square."); } - lu = this.getData(); + lu = getData(); // Initialize permutation array and parity permutation = new int[nRows]; @@ -795,22 +942,24 @@ public class RealMatrixImpl implements RealMatrix, Serializable { // upper for (int row = 0; row < col; row++) { - sum = lu[row][col]; + final double[] luRow = lu[row]; + sum = luRow[col]; for (int i = 0; i < row; i++) { - sum -= lu[row][i] * lu[i][col]; + sum -= luRow[i] * lu[i][col]; } - lu[row][col] = sum; + luRow[col] = sum; } // lower int max = col; // permutation row double largest = 0d; for (int row = col; row < nRows; row++) { - sum = lu[row][col]; + final double[] luRow = lu[row]; + sum = luRow[col]; for (int i = 0; i < col; i++) { - sum -= lu[row][i] * lu[i][col]; + sum -= luRow[i] * lu[i][col]; } - lu[row][col] = sum; + luRow[col] = sum; // maintain best permutation choice if (Math.abs(sum) > largest) { @@ -839,9 +988,10 @@ public class RealMatrixImpl implements RealMatrix, Serializable { parity = -parity; } - //Divide the lower elements by the "winning" diagonal elt. + // Divide the lower elements by the "winning" diagonal elt. + final double luDiag = lu[col][col]; for (int row = col + 1; row < nRows; row++) { - lu[row][col] /= lu[col][col]; + lu[row][col] /= luDiag; } } } @@ -855,12 +1005,14 @@ public class RealMatrixImpl implements RealMatrix, Serializable { res.append("RealMatrixImpl{"); if (data != null) { for (int i = 0; i < data.length; i++) { - if (i > 0) + if (i > 0) { res.append(","); + } res.append("{"); for (int j = 0; j < data[0].length; j++) { - if (j > 0) + if (j > 0) { res.append(","); + } res.append(data[i][j]); } res.append("}"); @@ -887,14 +1039,15 @@ public class RealMatrixImpl implements RealMatrix, Serializable { return false; } RealMatrix m = (RealMatrix) object; - int nRows = getRowDimension(); - int nCols = getColumnDimension(); + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); if (m.getColumnDimension() != nCols || m.getRowDimension() != nRows) { return false; } for (int row = 0; row < nRows; row++) { + final double[] dataRow = data[row]; for (int col = 0; col < nCols; col++) { - if (Double.doubleToLongBits(data[row][col]) != + if (Double.doubleToLongBits(dataRow[col]) != Double.doubleToLongBits(m.getEntry(row, col))) { return false; } @@ -910,14 +1063,15 @@ public class RealMatrixImpl implements RealMatrix, Serializable { */ public int hashCode() { int ret = 7; - int nRows = getRowDimension(); - int nCols = getColumnDimension(); + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); ret = ret * 31 + nRows; ret = ret * 31 + nCols; for (int row = 0; row < nRows; row++) { - for (int col = 0; col < nCols; col++) { + final double[] dataRow = data[row]; + for (int col = 0; col < nCols; col++) { ret = ret * 31 + (11 * (row+1) + 17 * (col+1)) * - MathUtils.hash(data[row][col]); + MathUtils.hash(dataRow[col]); } } return ret; @@ -972,7 +1126,7 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return the permutation */ protected int[] getPermutation() { - int[] out = new int[permutation.length]; + final int[] out = new int[permutation.length]; System.arraycopy(permutation, 0, out, 0, permutation.length); return out; } @@ -985,8 +1139,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return a copy of the underlying data array. */ private double[][] copyOut() { - int nRows = this.getRowDimension(); - double[][] out = new double[nRows][this.getColumnDimension()]; + final int nRows = this.getRowDimension(); + final double[][] out = new double[nRows][this.getColumnDimension()]; // can't copy 2-d array in one shot, otherwise get row references for (int i = 0; i < nRows; i++) { System.arraycopy(data[i], 0, out[i], 0, data[i].length); @@ -1016,9 +1170,8 @@ public class RealMatrixImpl implements RealMatrix, Serializable { * @return true if the coordinate is with the current dimensions */ private boolean isValidCoordinate(int row, int col) { - int nRows = this.getRowDimension(); - int nCols = this.getColumnDimension(); - + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); return !(row < 0 || row > nRows - 1 || col < 0 || col > nCols -1); } diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index 7b8d68376..acd5465eb 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -39,6 +39,10 @@ The type attribute can be add,update,fix,remove. + + Greatly improved RealMatrixImpl and BigMatrixImpl performances, + both in terms of speed and in terms of temporary memory footprint. + Added Mauro's patch to support multiple regression. diff --git a/src/test/org/apache/commons/math/linear/BigMatrixImplTest.java b/src/test/org/apache/commons/math/linear/BigMatrixImplTest.java index b1b982914..f30e04b9c 100644 --- a/src/test/org/apache/commons/math/linear/BigMatrixImplTest.java +++ b/src/test/org/apache/commons/math/linear/BigMatrixImplTest.java @@ -156,9 +156,12 @@ public final class BigMatrixImplTest extends TestCase { /** test copy functions */ public void testCopyFunctions() { - BigMatrixImpl m = new BigMatrixImpl(testData); - BigMatrixImpl m2 = new BigMatrixImpl(m.getData()); - assertEquals(m2,m); + BigMatrixImpl m1 = new BigMatrixImpl(testData); + BigMatrixImpl m2 = new BigMatrixImpl(m1.getData()); + assertEquals(m2,m1); + BigMatrixImpl m3 = new BigMatrixImpl(testData); + BigMatrixImpl m4 = new BigMatrixImpl(m3.getData(), false); + assertEquals(m4,m3); } /** test constructors */ @@ -166,9 +169,13 @@ public final class BigMatrixImplTest extends TestCase { BigMatrix m1 = new BigMatrixImpl(testData); BigMatrix m2 = new BigMatrixImpl(testDataString); BigMatrix m3 = new BigMatrixImpl(asBigDecimal(testData)); + BigMatrix m4 = new BigMatrixImpl(asBigDecimal(testData), true); + BigMatrix m5 = new BigMatrixImpl(asBigDecimal(testData), false); assertClose("double, string", m1, m2, Double.MIN_VALUE); assertClose("double, BigDecimal", m1, m3, Double.MIN_VALUE); assertClose("string, BigDecimal", m2, m3, Double.MIN_VALUE); + assertClose("double, BigDecimal/true", m1, m4, Double.MIN_VALUE); + assertClose("double, BigDecimal/false", m1, m5, Double.MIN_VALUE); try { new BigMatrixImpl(new String[][] {{"0", "hello", "1"}}); fail("Expecting NumberFormatException"); @@ -212,7 +219,7 @@ public final class BigMatrixImplTest extends TestCase { public void testAdd() { BigMatrixImpl m = new BigMatrixImpl(testData); BigMatrixImpl mInv = new BigMatrixImpl(testDataInv); - BigMatrixImpl mPlusMInv = (BigMatrixImpl)m.add(mInv); + BigMatrix mPlusMInv = m.add(mInv); double[][] sumEntries = asDouble(mPlusMInv.getData()); for (int row = 0; row < m.getRowDimension(); row++) { for (int col = 0; col < m.getColumnDimension(); col++) { diff --git a/src/test/org/apache/commons/math/linear/MatrixUtilsTest.java b/src/test/org/apache/commons/math/linear/MatrixUtilsTest.java index 90cba3a4b..bc251f41b 100644 --- a/src/test/org/apache/commons/math/linear/MatrixUtilsTest.java +++ b/src/test/org/apache/commons/math/linear/MatrixUtilsTest.java @@ -65,6 +65,10 @@ public final class MatrixUtilsTest extends TestCase { public void testCreateRealMatrix() { assertEquals(new RealMatrixImpl(testData), MatrixUtils.createRealMatrix(testData)); + assertEquals(new RealMatrixImpl(testData, false), + MatrixUtils.createRealMatrix(testData, true)); + assertEquals(new RealMatrixImpl(testData, true), + MatrixUtils.createRealMatrix(testData, false)); try { MatrixUtils.createRealMatrix(new double[][] {{1}, {1,2}}); // ragged fail("Expecting IllegalArgumentException"); @@ -88,6 +92,10 @@ public final class MatrixUtilsTest extends TestCase { public void testCreateBigMatrix() { assertEquals(new BigMatrixImpl(testData), MatrixUtils.createBigMatrix(testData)); + assertEquals(new BigMatrixImpl(BigMatrixImplTest.asBigDecimal(testData), true), + MatrixUtils.createBigMatrix(BigMatrixImplTest.asBigDecimal(testData), false)); + assertEquals(new BigMatrixImpl(BigMatrixImplTest.asBigDecimal(testData), false), + MatrixUtils.createBigMatrix(BigMatrixImplTest.asBigDecimal(testData), true)); assertEquals(new BigMatrixImpl(bigColMatrix), MatrixUtils.createBigMatrix(bigColMatrix)); assertEquals(new BigMatrixImpl(stringColMatrix), diff --git a/src/test/org/apache/commons/math/linear/QRDecompositionImplTest.java b/src/test/org/apache/commons/math/linear/QRDecompositionImplTest.java index c0604c2ca..1147a2a83 100644 --- a/src/test/org/apache/commons/math/linear/QRDecompositionImplTest.java +++ b/src/test/org/apache/commons/math/linear/QRDecompositionImplTest.java @@ -59,21 +59,21 @@ public class QRDecompositionImplTest extends TestCase { /** test dimensions */ public void testDimensions() { - RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular); + RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular, false); QRDecomposition qr = new QRDecompositionImpl(matrix); assertEquals("3x3 Q size", qr.getQ().getRowDimension(), 3); assertEquals("3x3 Q size", qr.getQ().getColumnDimension(), 3); assertEquals("3x3 R size", qr.getR().getRowDimension(), 3); assertEquals("3x3 R size", qr.getR().getColumnDimension(), 3); - matrix = new RealMatrixImpl(testData4x3); + matrix = new RealMatrixImpl(testData4x3, false); qr = new QRDecompositionImpl(matrix); assertEquals("4x3 Q size", qr.getQ().getRowDimension(), 4); assertEquals("4x3 Q size", qr.getQ().getColumnDimension(), 4); assertEquals("4x3 R size", qr.getR().getRowDimension(), 4); assertEquals("4x3 R size", qr.getR().getColumnDimension(), 3); - matrix = new RealMatrixImpl(testData3x4); + matrix = new RealMatrixImpl(testData3x4, false); qr = new QRDecompositionImpl(matrix); assertEquals("3x4 Q size", qr.getQ().getRowDimension(), 3); assertEquals("3x4 Q size", qr.getQ().getColumnDimension(), 3); @@ -83,24 +83,24 @@ public class QRDecompositionImplTest extends TestCase { /** test A = QR */ public void testAEqualQR() { - RealMatrix A = new RealMatrixImpl(testData3x3NonSingular); + RealMatrix A = new RealMatrixImpl(testData3x3NonSingular, false); QRDecomposition qr = new QRDecompositionImpl(A); RealMatrix Q = qr.getQ(); RealMatrix R = qr.getR(); double norm = Q.multiply(R).subtract(A).getNorm(); assertEquals("3x3 nonsingular A = QR", 0, norm, normTolerance); - RealMatrix matrix = new RealMatrixImpl(testData3x3Singular); + RealMatrix matrix = new RealMatrixImpl(testData3x3Singular, false); qr = new QRDecompositionImpl(matrix); norm = qr.getQ().multiply(qr.getR()).subtract(matrix).getNorm(); assertEquals("3x3 singular A = QR", 0, norm, normTolerance); - matrix = new RealMatrixImpl(testData3x4); + matrix = new RealMatrixImpl(testData3x4, false); qr = new QRDecompositionImpl(matrix); norm = qr.getQ().multiply(qr.getR()).subtract(matrix).getNorm(); assertEquals("3x4 A = QR", 0, norm, normTolerance); - matrix = new RealMatrixImpl(testData4x3); + matrix = new RealMatrixImpl(testData4x3, false); qr = new QRDecompositionImpl(matrix); norm = qr.getQ().multiply(qr.getR()).subtract(matrix).getNorm(); assertEquals("4x3 A = QR", 0, norm, normTolerance); @@ -108,28 +108,28 @@ public class QRDecompositionImplTest extends TestCase { /** test the orthogonality of Q */ public void testQOrthogonal() { - RealMatrix matrix = new RealMatrixImpl(testData3x3NonSingular); + RealMatrix matrix = new RealMatrixImpl(testData3x3NonSingular, false); matrix = new QRDecompositionImpl(matrix).getQ(); RealMatrix eye = MatrixUtils.createRealIdentityMatrix(3); double norm = matrix.transpose().multiply(matrix).subtract(eye) .getNorm(); assertEquals("3x3 nonsingular Q'Q = I", 0, norm, normTolerance); - matrix = new RealMatrixImpl(testData3x3Singular); + matrix = new RealMatrixImpl(testData3x3Singular, false); matrix = new QRDecompositionImpl(matrix).getQ(); eye = MatrixUtils.createRealIdentityMatrix(3); norm = matrix.transpose().multiply(matrix).subtract(eye) .getNorm(); assertEquals("3x3 singular Q'Q = I", 0, norm, normTolerance); - matrix = new RealMatrixImpl(testData3x4); + matrix = new RealMatrixImpl(testData3x4, false); matrix = new QRDecompositionImpl(matrix).getQ(); eye = MatrixUtils.createRealIdentityMatrix(3); norm = matrix.transpose().multiply(matrix).subtract(eye) .getNorm(); assertEquals("3x4 Q'Q = I", 0, norm, normTolerance); - matrix = new RealMatrixImpl(testData4x3); + matrix = new RealMatrixImpl(testData4x3, false); matrix = new QRDecompositionImpl(matrix).getQ(); eye = MatrixUtils.createRealIdentityMatrix(4); norm = matrix.transpose().multiply(matrix).subtract(eye) @@ -139,21 +139,21 @@ public class QRDecompositionImplTest extends TestCase { /** test that R is upper triangular */ public void testRUpperTriangular() { - RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular); + RealMatrixImpl matrix = new RealMatrixImpl(testData3x3NonSingular, false); RealMatrix R = new QRDecompositionImpl(matrix).getR(); for (int i = 0; i < R.getRowDimension(); i++) for (int j = 0; j < i; j++) assertEquals("R lower triangle", R.getEntry(i, j), 0, entryTolerance); - matrix = new RealMatrixImpl(testData3x4); + matrix = new RealMatrixImpl(testData3x4, false); R = new QRDecompositionImpl(matrix).getR(); for (int i = 0; i < R.getRowDimension(); i++) for (int j = 0; j < i; j++) assertEquals("R lower triangle", R.getEntry(i, j), 0, entryTolerance); - matrix = new RealMatrixImpl(testData4x3); + matrix = new RealMatrixImpl(testData4x3, false); R = new QRDecompositionImpl(matrix).getR(); for (int i = 0; i < R.getRowDimension(); i++) for (int j = 0; j < i; j++) diff --git a/src/test/org/apache/commons/math/linear/RealMatrixImplTest.java b/src/test/org/apache/commons/math/linear/RealMatrixImplTest.java index 38fafd6b2..7553f3675 100644 --- a/src/test/org/apache/commons/math/linear/RealMatrixImplTest.java +++ b/src/test/org/apache/commons/math/linear/RealMatrixImplTest.java @@ -116,16 +116,19 @@ public final class RealMatrixImplTest extends TestCase { /** test copy functions */ public void testCopyFunctions() { - RealMatrixImpl m = new RealMatrixImpl(testData); - RealMatrixImpl m2 = new RealMatrixImpl(m.getData()); - assertEquals(m2,m); + RealMatrixImpl m1 = new RealMatrixImpl(testData); + RealMatrixImpl m2 = new RealMatrixImpl(m1.getData()); + assertEquals(m2,m1); + RealMatrixImpl m3 = new RealMatrixImpl(testData); + RealMatrixImpl m4 = new RealMatrixImpl(m3.getData(), false); + assertEquals(m4,m3); } /** test add */ public void testAdd() { RealMatrixImpl m = new RealMatrixImpl(testData); RealMatrixImpl mInv = new RealMatrixImpl(testDataInv); - RealMatrixImpl mPlusMInv = (RealMatrixImpl)m.add(mInv); + RealMatrix mPlusMInv = m.add(mInv); double[][] sumEntries = mPlusMInv.getData(); for (int row = 0; row < m.getRowDimension(); row++) { for (int col = 0; col < m.getColumnDimension(); col++) {