Added support for sparse real matrices

JIRA: MATH-230

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@726460 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-12-14 15:04:22 +00:00
parent 578a219a0f
commit 11d8f0ec5f
7 changed files with 985 additions and 67 deletions

View File

@ -135,6 +135,9 @@
<contributor>
<name>Fredrik Norin</name>
</contributor>
<contributor>
<name>Sujit Pal</name>
</contributor>
<contributor>
<name>Todd C. Parnell</name>
</contributor>

View File

@ -78,19 +78,12 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
/** {@inheritDoc} */
public RealMatrix add(RealMatrix m) throws IllegalArgumentException {
// safety check
checkAdditionCompatible(m);
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" addition compatible",
new Object[] {
getRowDimension(),
getColumnDimension(),
m.getRowDimension(),
m.getColumnDimension()
});
}
final RealMatrix out = createMatrix(rowCount, columnCount);
for (int row = 0; row < rowCount; ++row) {
for (int col = 0; col < columnCount; ++col) {
@ -104,19 +97,12 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
/** {@inheritDoc} */
public RealMatrix subtract(final RealMatrix m) throws IllegalArgumentException {
// safety check
checkSubtractionCompatible(m);
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" subtraction compatible",
new Object[] {
getRowDimension(),
getColumnDimension(),
m.getRowDimension(),
m.getColumnDimension()
});
}
final RealMatrix out = createMatrix(rowCount, columnCount);
for (int row = 0; row < rowCount; ++row) {
for (int col = 0; col < columnCount; ++col) {
@ -163,16 +149,9 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
/** {@inheritDoc} */
public RealMatrix multiply(final RealMatrix m)
throws IllegalArgumentException {
if (getColumnDimension() != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" multiplication compatible",
new Object[] {
getRowDimension(),
getColumnDimension(),
m.getRowDimension(),
m.getColumnDimension()
});
}
// safety check
checkMultiplicationCompatible(m);
final int nRows = getRowDimension();
final int nCols = m.getColumnDimension();
@ -199,7 +178,20 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
}
/** {@inheritDoc} */
public abstract double[][] getData();
public double[][] getData() {
final double[][] data = new double[getRowDimension()][getColumnDimension()];
for (int i = 0; i < data.length; ++i) {
final double[] dataI = data[i];
for (int j = 0; j < dataI.length; ++j) {
dataI[j] = getEntry(i, j);
}
}
return data;
}
/** {@inheritDoc} */
public double getNorm() {
@ -767,7 +759,9 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
final int nRows = getRowDimension();
final int nCols = getColumnDimension();
final StringBuffer res = new StringBuffer();
res.append("RealMatrixImpl{");
String fullClassName = getClass().getName();
String shortClassName = fullClassName.substring(fullClassName.lastIndexOf('.') + 1);
res.append(shortClassName).append("{");
for (int i = 0; i < nRows; ++i) {
if (i > 0) {
@ -844,7 +838,7 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
* @param row row index to check
* @exception MatrixIndexException if index is not valid
*/
private void checkRowIndex(final int row) {
protected void checkRowIndex(final int row) {
if (row < 0 || row >= getRowDimension()) {
throw new MatrixIndexException("row index {0} out of allowed range [{1}, {2}]",
new Object[] { row, 0, getRowDimension() - 1});
@ -856,7 +850,7 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
* @param column column index to check
* @exception MatrixIndexException if index is not valid
*/
private void checkColumnIndex(final int column)
protected void checkColumnIndex(final int column)
throws MatrixIndexException {
if (column < 0 || column >= getColumnDimension()) {
throw new MatrixIndexException("column index {0} out of allowed range [{1}, {2}]",
@ -864,4 +858,60 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable {
}
}
/**
* Check if a matrix is addition compatible with the instance
* @param m matrix to check
* @exception IllegalArgumentException if matrix is not addition compatible with instance
*/
protected void checkAdditionCompatible(final RealMatrix m) {
if ((getRowDimension() != m.getRowDimension()) ||
(getColumnDimension() != m.getColumnDimension())) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" addition compatible",
new Object[] {
getRowDimension(),
getColumnDimension(),
m.getRowDimension(),
m.getColumnDimension()
});
}
}
/**
* Check if a matrix is subtraction compatible with the instance
* @param m matrix to check
* @exception IllegalArgumentException if matrix is not subtraction compatible with instance
*/
protected void checkSubtractionCompatible(final RealMatrix m) {
if ((getRowDimension() != m.getRowDimension()) ||
(getColumnDimension() != m.getColumnDimension())) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" subtraction compatible",
new Object[] {
getRowDimension(),
getColumnDimension(),
m.getRowDimension(),
m.getColumnDimension()
});
}
}
/**
* Check if a matrix is multiplication compatible with the instance
* @param m matrix to check
* @exception IllegalArgumentException if matrix is not multiplication compatible with instance
*/
protected void checkMultiplicationCompatible(final RealMatrix m) {
if (getColumnDimension() != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" multiplication compatible",
new Object[] {
getRowDimension(),
getColumnDimension(),
m.getRowDimension(),
m.getColumnDimension()
});
}
}
}

View File

@ -185,16 +185,12 @@ public class RealMatrixImpl extends AbstractRealMatrix implements Serializable {
*/
public RealMatrixImpl add(final RealMatrixImpl m)
throws IllegalArgumentException {
// safety check
checkAdditionCompatible(m);
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" addition compatible",
new Object[] {
getRowDimension(), getColumnDimension(),
m.getRowDimension(), m.getColumnDimension()
});
}
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
@ -204,7 +200,9 @@ public class RealMatrixImpl extends AbstractRealMatrix implements Serializable {
outDataRow[col] = dataRow[col] + mRow[col];
}
}
return new RealMatrixImpl(outData, false);
}
/** {@inheritDoc} */
@ -226,16 +224,12 @@ public class RealMatrixImpl extends AbstractRealMatrix implements Serializable {
*/
public RealMatrixImpl subtract(final RealMatrixImpl m)
throws IllegalArgumentException {
// safety check
checkSubtractionCompatible(m);
final int rowCount = getRowDimension();
final int columnCount = getColumnDimension();
if (columnCount != m.getColumnDimension() || rowCount != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" subtraction compatible",
new Object[] {
getRowDimension(), getColumnDimension(),
m.getRowDimension(), m.getColumnDimension()
});
}
final double[][] outData = new double[rowCount][columnCount];
for (int row = 0; row < rowCount; row++) {
final double[] dataRow = data[row];
@ -245,7 +239,9 @@ public class RealMatrixImpl extends AbstractRealMatrix implements Serializable {
outDataRow[col] = dataRow[col] - mRow[col];
}
}
return new RealMatrixImpl(outData, false);
}
/** {@inheritDoc} */
@ -267,14 +263,10 @@ public class RealMatrixImpl extends AbstractRealMatrix implements Serializable {
*/
public RealMatrixImpl multiply(final RealMatrixImpl m)
throws IllegalArgumentException {
if (this.getColumnDimension() != m.getRowDimension()) {
throw MathRuntimeException.createIllegalArgumentException("{0}x{1} and {2}x{3} matrices are not" +
" multiplication compatible",
new Object[] {
getRowDimension(), getColumnDimension(),
m.getRowDimension(), m.getColumnDimension()
});
}
// safety check
checkMultiplicationCompatible(m);
final int nRows = this.getRowDimension();
final int nCols = m.getColumnDimension();
final int nSum = this.getColumnDimension();
@ -290,7 +282,9 @@ public class RealMatrixImpl extends AbstractRealMatrix implements Serializable {
outDataRow[col] = sum;
}
}
return new RealMatrixImpl(outData, false);
}
/** {@inheritDoc} */

View File

@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.util.OpenIntToDoubleHashMap;
/**
* Sparse matrix implementation based on an open addressed map.
*
* @version $Revision$ $Date$
* @since 2.0
*/
public class SparseRealMatrix extends AbstractRealMatrix {
/** Serializable version identifier. */
private static final long serialVersionUID = -5962461716457143437L;
/** Number of rows of the matrix. */
private final int rowDimension;
/** Number of columns of the matrix. */
private final int columnDimension;
/** Storage for (sparse) matrix elements. */
private OpenIntToDoubleHashMap entries;
/**
* Build a sparse matrix with the supplied row and column dimensions.
* @param rowDimension number of rows of the matrix
* @param columnDimension number of columns of the matrix
*/
public SparseRealMatrix(int rowDimension, int columnDimension) {
super(rowDimension, columnDimension);
this.rowDimension = rowDimension;
this.columnDimension = columnDimension;
this.entries = new OpenIntToDoubleHashMap();
}
/**
* Build a matrix by copying another one.
* @param matrix matrix to copy
*/
public SparseRealMatrix(SparseRealMatrix matrix) {
this.rowDimension = matrix.rowDimension;
this.columnDimension = matrix.columnDimension;
this.entries = new OpenIntToDoubleHashMap(matrix.entries);
}
/** {@inheritDoc} */
@Override
public RealMatrix copy() {
return new SparseRealMatrix(this);
}
/** {@inheritDoc} */
@Override
public RealMatrix createMatrix(int rowDimension, int columnDimension)
throws IllegalArgumentException {
return new SparseRealMatrix(rowDimension, columnDimension);
}
/** {@inheritDoc} */
@Override
public int getColumnDimension() {
return columnDimension;
}
/** {@inheritDoc} */
public RealMatrix add(final RealMatrix m)
throws IllegalArgumentException {
try {
return add((SparseRealMatrix) m);
} catch (ClassCastException cce) {
return super.add(m);
}
}
/**
* Compute the sum of this and <code>m</code>.
*
* @param m matrix to be added
* @return this + m
* @throws IllegalArgumentException if m is not the same size as this
*/
public RealMatrix add(SparseRealMatrix m) throws IllegalArgumentException {
// safety check
checkAdditionCompatible(m);
final RealMatrix out = new SparseRealMatrix(this);
for (OpenIntToDoubleHashMap.Iterator iterator = m.entries.iterator(); iterator.hasNext();) {
final OpenIntToDoubleHashMap.Entry entry = iterator.next();
final int row = entry.key() / columnDimension;
final int col = entry.key() - row * columnDimension;
out.setEntry(row, col, getEntry(row, col) + entry.value());
}
return out;
}
/** {@inheritDoc} */
public RealMatrix subtract(final RealMatrix m)
throws IllegalArgumentException {
try {
return subtract((SparseRealMatrix) m);
} catch (ClassCastException cce) {
return super.add(m);
}
}
/**
* Compute this minus <code>m</code>.
*
* @param m matrix to be subtracted
* @return this - m
* @throws IllegalArgumentException if m is not the same size as this
*/
public RealMatrix subtract(SparseRealMatrix m) throws IllegalArgumentException {
// safety check
checkAdditionCompatible(m);
final RealMatrix out = new SparseRealMatrix(this);
for (OpenIntToDoubleHashMap.Iterator iterator = m.entries.iterator(); iterator.hasNext();) {
final OpenIntToDoubleHashMap.Entry entry = iterator.next();
final int row = entry.key() / columnDimension;
final int col = entry.key() - row * columnDimension;
out.setEntry(row, col, getEntry(row, col) - entry.value());
}
return out;
}
/** {@inheritDoc} */
@Override
public double getEntry(int row, int column) throws MatrixIndexException {
checkRowIndex(row);
checkColumnIndex(column);
return entries.get(computeKey(row, column));
}
/** {@inheritDoc} */
@Override
public int getRowDimension() {
return rowDimension;
}
/** {@inheritDoc} */
@Override
public void setEntry(int row, int column, double value)
throws MatrixIndexException {
checkRowIndex(row);
checkColumnIndex(column);
if (value == 0.0) {
entries.remove(computeKey(row, column));
} else {
entries.put(computeKey(row, column), value);
}
}
/**
* Compute the key to access a matrix element
* @param row row index of the matrix element
* @param column column index of the matrix element
* @return key within the map to access the matrix element
*/
private int computeKey(int row, int column) {
return row * columnDimension + column;
}
}

View File

@ -39,6 +39,9 @@ The <action> type attribute can be add,update,fix,remove.
</properties>
<body>
<release version="2.0" date="TBD" description="TBD">
<action dev="luc" type="add" issue="MATH-230" due-to="Sujit Pal and Ismael Juma">
Added support for sparse matrix.
</action>
<action dev="luc" type="add" due-to="Ismael Juma">
Added an int/double hash map (OpenIntToDoubleHashMap) with much smaller
memory overhead than standard java.util.Map (open addressing and no boxing).

View File

@ -29,9 +29,10 @@
<section name="3 Linear Algebra">
<subsection name="3.1 Overview" href="overview">
<p>
Currently, numerical linear algebra support in commons-math is
limited to basic operations on real matrices and vectors and
solving linear systems.
Linear algebra support in commons-math provides operations on real matrices
(both dense and sparse matrices are supported) and vectors. It features basic
operations (addition, subtraction ...) and decomposition algorithms that can
be used to solve linear systems either in exact sense and in least squares sense.
</p>
</subsection>
<subsection name="3.2 Real matrices" href="real_matrices">
@ -69,6 +70,13 @@ System.out.println(p.getColumnDimension()); // 2
RealMatrix pInverse = new LUSolver(new LUDecompositionImpl(p))).getInverse();
</source>
</p>
<p>
The two main implementations of the interface are <a
href="../apidocs/org/apache/commons/math/linear/RealMatrixImpl.html">
RealMatrixImpl</a> for dense matrices and <a
href="../apidocs/org/apache/commons/math/linear/SparseRealMatrix.html">
SparseRealMatrix</a> for sparse matrices.
</p>
</subsection>
<subsection name="3.3 Real vectors" href="real_vectors">
<p>

View File

@ -0,0 +1,672 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
/**
* Test cases for the {@link SparseRealMatrix} class.
*
* @version $Revision$ $Date: 2008-11-07 06:48:13 -0800 (Fri, 07 Nov
* 2008) $
*/
public final class SparseRealMatrixTest extends TestCase {
// 3 x 3 identity matrix
protected double[][] id = { { 1d, 0d, 0d }, { 0d, 1d, 0d }, { 0d, 0d, 1d } };
// Test data for group operations
protected double[][] testData = { { 1d, 2d, 3d }, { 2d, 5d, 3d },
{ 1d, 0d, 8d } };
protected double[][] testDataLU = { { 2d, 5d, 3d }, { .5d, -2.5d, 6.5d },
{ 0.5d, 0.2d, .2d } };
protected double[][] testDataPlus2 = { { 3d, 4d, 5d }, { 4d, 7d, 5d },
{ 3d, 2d, 10d } };
protected double[][] testDataMinus = { { -1d, -2d, -3d },
{ -2d, -5d, -3d }, { -1d, 0d, -8d } };
protected double[] testDataRow1 = { 1d, 2d, 3d };
protected double[] testDataCol3 = { 3d, 3d, 8d };
protected double[][] testDataInv = { { -40d, 16d, 9d }, { 13d, -5d, -3d },
{ 5d, -2d, -1d } };
protected double[] preMultTest = { 8, 12, 33 };
protected double[][] testData2 = { { 1d, 2d, 3d }, { 2d, 5d, 3d } };
protected double[][] testData2T = { { 1d, 2d }, { 2d, 5d }, { 3d, 3d } };
protected double[][] testDataPlusInv = { { -39d, 18d, 12d },
{ 15d, 0d, 0d }, { 6d, -2d, 7d } };
// lu decomposition tests
protected double[][] luData = { { 2d, 3d, 3d }, { 0d, 5d, 7d }, { 6d, 9d, 8d } };
protected double[][] luDataLUDecomposition = { { 6d, 9d, 8d },
{ 0d, 5d, 7d }, { 0.33333333333333, 0d, 0.33333333333333 } };
// singular matrices
protected double[][] singular = { { 2d, 3d }, { 2d, 3d } };
protected double[][] bigSingular = { { 1d, 2d, 3d, 4d },
{ 2d, 5d, 3d, 4d }, { 7d, 3d, 256d, 1930d }, { 3d, 7d, 6d, 8d } }; // 4th
// row
// =
// 1st
// +
// 2nd
protected double[][] detData = { { 1d, 2d, 3d }, { 4d, 5d, 6d },
{ 7d, 8d, 10d } };
protected double[][] detData2 = { { 1d, 3d }, { 2d, 4d } };
// vectors
protected double[] testVector = { 1, 2, 3 };
protected double[] testVector2 = { 1, 2, 3, 4 };
// submatrix accessor tests
protected double[][] subTestData = { { 1, 2, 3, 4 },
{ 1.5, 2.5, 3.5, 4.5 }, { 2, 4, 6, 8 }, { 4, 5, 6, 7 } };
// array selections
protected double[][] subRows02Cols13 = { { 2, 4 }, { 4, 8 } };
protected double[][] subRows03Cols12 = { { 2, 3 }, { 5, 6 } };
protected double[][] subRows03Cols123 = { { 2, 3, 4 }, { 5, 6, 7 } };
// effective permutations
protected double[][] subRows20Cols123 = { { 4, 6, 8 }, { 2, 3, 4 } };
protected double[][] subRows31Cols31 = { { 7, 5 }, { 4.5, 2.5 } };
// contiguous ranges
protected double[][] subRows01Cols23 = { { 3, 4 }, { 3.5, 4.5 } };
protected double[][] subRows23Cols00 = { { 2 }, { 4 } };
protected double[][] subRows00Cols33 = { { 4 } };
// row matrices
protected double[][] subRow0 = { { 1, 2, 3, 4 } };
protected double[][] subRow3 = { { 4, 5, 6, 7 } };
// column matrices
protected double[][] subColumn1 = { { 2 }, { 2.5 }, { 4 }, { 5 } };
protected double[][] subColumn3 = { { 4 }, { 4.5 }, { 8 }, { 7 } };
// tolerances
protected double entryTolerance = 10E-16;
protected double normTolerance = 10E-14;
public SparseRealMatrixTest(String name) {
super(name);
}
public void setUp() {
}
public static Test suite() {
TestSuite suite = new TestSuite(SparseRealMatrixTest.class);
suite.setName("SparseRealMatrix Tests");
return suite;
}
/** test dimensions */
public void testDimensions() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix m2 = createSparseMatrix(testData2);
assertEquals("testData row dimension", 3, m.getRowDimension());
assertEquals("testData column dimension", 3, m.getColumnDimension());
assertTrue("testData is square", m.isSquare());
assertEquals("testData2 row dimension", m2.getRowDimension(), 2);
assertEquals("testData2 column dimension", m2.getColumnDimension(), 3);
assertTrue("testData2 is not square", !m2.isSquare());
}
/** test copy functions */
public void testCopyFunctions() {
SparseRealMatrix m1 = createSparseMatrix(testData);
RealMatrix m2 = m1.copy();
assertTrue(m2 instanceof SparseRealMatrix);
assertEquals(((SparseRealMatrix) m2), m1);
SparseRealMatrix m3 = createSparseMatrix(testData);
RealMatrix m4 = m3.copy();
assertTrue(m4 instanceof SparseRealMatrix);
assertEquals(((SparseRealMatrix) m4), m3);
}
/** test add */
public void testAdd() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix mInv = createSparseMatrix(testDataInv);
SparseRealMatrix mDataPlusInv = createSparseMatrix(testDataPlusInv);
RealMatrix mPlusMInv = m.add(mInv);
for (int row = 0; row < m.getRowDimension(); row++) {
for (int col = 0; col < m.getColumnDimension(); col++) {
assertEquals("sum entry entry",
mDataPlusInv.getEntry(row, col), mPlusMInv.getEntry(row, col),
entryTolerance);
}
}
}
/** test add failure */
public void testAddFail() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix m2 = createSparseMatrix(testData2);
try {
m.add(m2);
fail("IllegalArgumentException expected");
} catch (IllegalArgumentException ex) {
;
}
}
/** test norm */
public void testNorm() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix m2 = createSparseMatrix(testData2);
assertEquals("testData norm", 14d, m.getNorm(), entryTolerance);
assertEquals("testData2 norm", 7d, m2.getNorm(), entryTolerance);
}
/** test m-n = m + -n */
public void testPlusMinus() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix n = createSparseMatrix(testDataInv);
assertClose("m-n = m + -n", m.subtract(n),
n.scalarMultiply(-1d).add(m), entryTolerance);
try {
m.subtract(createSparseMatrix(testData2));
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test multiply */
public void testMultiply() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix mInv = createSparseMatrix(testDataInv);
SparseRealMatrix identity = createSparseMatrix(id);
SparseRealMatrix m2 = createSparseMatrix(testData2);
assertClose("inverse multiply", m.multiply(mInv), identity,
entryTolerance);
assertClose("inverse multiply", mInv.multiply(m), identity,
entryTolerance);
assertClose("identity multiply", m.multiply(identity), m,
entryTolerance);
assertClose("identity multiply", identity.multiply(mInv), mInv,
entryTolerance);
assertClose("identity multiply", m2.multiply(identity), m2,
entryTolerance);
try {
m.multiply(createSparseMatrix(bigSingular));
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
// Additional Test for RealMatrixImplTest.testMultiply
private double[][] d3 = new double[][] { { 1, 2, 3, 4 }, { 5, 6, 7, 8 } };
private double[][] d4 = new double[][] { { 1 }, { 2 }, { 3 }, { 4 } };
private double[][] d5 = new double[][] { { 30 }, { 70 } };
public void testMultiply2() {
RealMatrix m3 = createSparseMatrix(d3);
RealMatrix m4 = createSparseMatrix(d4);
RealMatrix m5 = createSparseMatrix(d5);
assertClose("m3*m4=m5", m3.multiply(m4), m5, entryTolerance);
}
/** test trace */
public void testTrace() {
RealMatrix m = createSparseMatrix(id);
assertEquals("identity trace", 3d, m.getTrace(), entryTolerance);
m = createSparseMatrix(testData2);
try {
m.getTrace();
fail("Expecting NonSquareMatrixException");
} catch (NonSquareMatrixException ex) {
;
}
}
/** test sclarAdd */
public void testScalarAdd() {
RealMatrix m = createSparseMatrix(testData);
assertClose("scalar add", createSparseMatrix(testDataPlus2),
m.scalarAdd(2d), entryTolerance);
}
/** test operate */
public void testOperate() {
RealMatrix m = createSparseMatrix(id);
assertClose("identity operate", testVector, m.operate(testVector),
entryTolerance);
assertClose("identity operate", testVector, m.operate(
new RealVectorImpl(testVector)).getData(), entryTolerance);
m = createSparseMatrix(bigSingular);
try {
m.operate(testVector);
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
/** test issue MATH-209 */
public void testMath209() {
RealMatrix a = createSparseMatrix(new double[][] {
{ 1, 2 }, { 3, 4 }, { 5, 6 } });
double[] b = a.operate(new double[] { 1, 1 });
assertEquals(a.getRowDimension(), b.length);
assertEquals(3.0, b[0], 1.0e-12);
assertEquals(7.0, b[1], 1.0e-12);
assertEquals(11.0, b[2], 1.0e-12);
}
/** test transpose */
public void testTranspose() {
RealMatrix m = createSparseMatrix(testData);
RealMatrix mIT = new LUSolver(new LUDecompositionImpl(m)).getInverse().transpose();
RealMatrix mTI = new LUSolver(new LUDecompositionImpl(m.transpose())).getInverse();
assertClose("inverse-transpose", mIT, mTI, normTolerance);
m = createSparseMatrix(testData2);
RealMatrix mt = createSparseMatrix(testData2T);
assertClose("transpose",mt,m.transpose(),normTolerance);
}
/** test preMultiply by vector */
public void testPremultiplyVector() {
RealMatrix m = createSparseMatrix(testData);
assertClose("premultiply", m.preMultiply(testVector), preMultTest,
normTolerance);
assertClose("premultiply", m.preMultiply(
new RealVectorImpl(testVector).getData()), preMultTest, normTolerance);
m = createSparseMatrix(bigSingular);
try {
m.preMultiply(testVector);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
public void testPremultiply() {
RealMatrix m3 = createSparseMatrix(d3);
RealMatrix m4 = createSparseMatrix(d4);
RealMatrix m5 = createSparseMatrix(d5);
assertClose("m3*m4=m5", m4.preMultiply(m3), m5, entryTolerance);
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix mInv = createSparseMatrix(testDataInv);
SparseRealMatrix identity = createSparseMatrix(id);
assertClose("inverse multiply", m.preMultiply(mInv), identity,
entryTolerance);
assertClose("inverse multiply", mInv.preMultiply(m), identity,
entryTolerance);
assertClose("identity multiply", m.preMultiply(identity), m,
entryTolerance);
assertClose("identity multiply", identity.preMultiply(mInv), mInv,
entryTolerance);
try {
m.preMultiply(createSparseMatrix(bigSingular));
fail("Expecting illegalArgumentException");
} catch (IllegalArgumentException ex) {
;
}
}
public void testGetVectors() {
RealMatrix m = createSparseMatrix(testData);
assertClose("get row", m.getRow(0), testDataRow1, entryTolerance);
assertClose("get col", m.getColumn(2), testDataCol3, entryTolerance);
try {
m.getRow(10);
fail("expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
;
}
try {
m.getColumn(-1);
fail("expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
;
}
}
public void testGetEntry() {
RealMatrix m = createSparseMatrix(testData);
assertEquals("get entry", m.getEntry(0, 1), 2d, entryTolerance);
try {
m.getEntry(10, 4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
}
/** test examples in user guide */
public void testExamples() {
// Create a real matrix with two rows and three columns
double[][] matrixData = { { 1d, 2d, 3d }, { 2d, 5d, 3d } };
RealMatrix m = createSparseMatrix(matrixData);
// One more with three rows, two columns
double[][] matrixData2 = { { 1d, 2d }, { 2d, 5d }, { 1d, 7d } };
RealMatrix n = createSparseMatrix(matrixData2);
// Now multiply m by n
RealMatrix p = m.multiply(n);
assertEquals(2, p.getRowDimension());
assertEquals(2, p.getColumnDimension());
// Invert p
RealMatrix pInverse = new LUSolver(new LUDecompositionImpl(p)).getInverse();
assertEquals(2, pInverse.getRowDimension());
assertEquals(2, pInverse.getColumnDimension());
// Solve example
double[][] coefficientsData = { { 2, 3, -2 }, { -1, 7, 6 },
{ 4, -3, -5 } };
RealMatrix coefficients = createSparseMatrix(coefficientsData);
double[] constants = { 1, -2, 1 };
double[] solution = new LUSolver(new LUDecompositionImpl(coefficients)).solve(constants);
assertEquals(2 * solution[0] + 3 * solution[1] - 2 * solution[2],
constants[0], 1E-12);
assertEquals(-1 * solution[0] + 7 * solution[1] + 6 * solution[2],
constants[1], 1E-12);
assertEquals(4 * solution[0] - 3 * solution[1] - 5 * solution[2],
constants[2], 1E-12);
}
// test submatrix accessors
public void testSubMatrix() {
RealMatrix m = createSparseMatrix(subTestData);
RealMatrix mRows23Cols00 = createSparseMatrix(subRows23Cols00);
RealMatrix mRows00Cols33 = createSparseMatrix(subRows00Cols33);
RealMatrix mRows01Cols23 = createSparseMatrix(subRows01Cols23);
RealMatrix mRows02Cols13 = createSparseMatrix(subRows02Cols13);
RealMatrix mRows03Cols12 = createSparseMatrix(subRows03Cols12);
RealMatrix mRows03Cols123 = createSparseMatrix(subRows03Cols123);
RealMatrix mRows20Cols123 = createSparseMatrix(subRows20Cols123);
RealMatrix mRows31Cols31 = createSparseMatrix(subRows31Cols31);
assertEquals("Rows23Cols00", mRows23Cols00, m.getSubMatrix(2, 3, 0, 0));
assertEquals("Rows00Cols33", mRows00Cols33, m.getSubMatrix(0, 0, 3, 3));
assertEquals("Rows01Cols23", mRows01Cols23, m.getSubMatrix(0, 1, 2, 3));
assertEquals("Rows02Cols13", mRows02Cols13,
m.getSubMatrix(new int[] { 0, 2 }, new int[] { 1, 3 }));
assertEquals("Rows03Cols12", mRows03Cols12,
m.getSubMatrix(new int[] { 0, 3 }, new int[] { 1, 2 }));
assertEquals("Rows03Cols123", mRows03Cols123,
m.getSubMatrix(new int[] { 0, 3 }, new int[] { 1, 2, 3 }));
assertEquals("Rows20Cols123", mRows20Cols123,
m.getSubMatrix(new int[] { 2, 0 }, new int[] { 1, 2, 3 }));
assertEquals("Rows31Cols31", mRows31Cols31,
m.getSubMatrix(new int[] { 3, 1 }, new int[] { 3, 1 }));
assertEquals("Rows31Cols31", mRows31Cols31,
m.getSubMatrix(new int[] { 3, 1 }, new int[] { 3, 1 }));
try {
m.getSubMatrix(1, 0, 2, 4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getSubMatrix(-1, 1, 2, 2);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getSubMatrix(1, 0, 2, 2);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getSubMatrix(1, 0, 2, 4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getSubMatrix(new int[] {}, new int[] { 0 });
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getSubMatrix(new int[] { 0 }, new int[] { 4 });
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
}
public void testGetRowMatrix() {
RealMatrix m = createSparseMatrix(subTestData);
RealMatrix mRow0 = createSparseMatrix(subRow0);
RealMatrix mRow3 = createSparseMatrix(subRow3);
assertEquals("Row0", mRow0, m.getRowMatrix(0));
assertEquals("Row3", mRow3, m.getRowMatrix(3));
try {
m.getRowMatrix(-1);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getRowMatrix(4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
}
public void testGetColumnMatrix() {
RealMatrix m = createSparseMatrix(subTestData);
RealMatrix mColumn1 = createSparseMatrix(subColumn1);
RealMatrix mColumn3 = createSparseMatrix(subColumn3);
assertEquals("Column1", mColumn1, m.getColumnMatrix(1));
assertEquals("Column3", mColumn3, m.getColumnMatrix(3));
try {
m.getColumnMatrix(-1);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getColumnMatrix(4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
}
public void testGetRowVector() {
RealMatrix m = createSparseMatrix(subTestData);
RealVector mRow0 = new RealVectorImpl(subRow0[0]);
RealVector mRow3 = new RealVectorImpl(subRow3[0]);
assertEquals("Row0", mRow0, m.getRowVector(0));
assertEquals("Row3", mRow3, m.getRowVector(3));
try {
m.getRowVector(-1);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getRowVector(4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
}
public void testGetColumnVector() {
RealMatrix m = createSparseMatrix(subTestData);
RealVector mColumn1 = columnToVector(subColumn1);
RealVector mColumn3 = columnToVector(subColumn3);
assertEquals("Column1", mColumn1, m.getColumnVector(1));
assertEquals("Column3", mColumn3, m.getColumnVector(3));
try {
m.getColumnVector(-1);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
try {
m.getColumnVector(4);
fail("Expecting MatrixIndexException");
} catch (MatrixIndexException ex) {
// expected
}
}
private RealVector columnToVector(double[][] column) {
double[] data = new double[column.length];
for (int i = 0; i < data.length; ++i) {
data[i] = column[i][0];
}
return new RealVectorImpl(data, false);
}
public void testEqualsAndHashCode() {
SparseRealMatrix m = createSparseMatrix(testData);
SparseRealMatrix m1 = (SparseRealMatrix) m.copy();
SparseRealMatrix mt = (SparseRealMatrix) m.transpose();
assertTrue(m.hashCode() != mt.hashCode());
assertEquals(m.hashCode(), m1.hashCode());
assertEquals(m, m);
assertEquals(m, m1);
assertFalse(m.equals(null));
assertFalse(m.equals(mt));
assertFalse(m.equals(createSparseMatrix(bigSingular)));
}
public void testToString() {
SparseRealMatrix m = createSparseMatrix(testData);
assertEquals("SparseRealMatrix{{1.0,2.0,3.0},{2.0,5.0,3.0},{1.0,0.0,8.0}}",
m.toString());
m = new SparseRealMatrix(1, 1);
assertEquals("SparseRealMatrix{{0.0}}", m.toString());
}
public void testSetSubMatrix() throws Exception {
SparseRealMatrix m = createSparseMatrix(testData);
m.setSubMatrix(detData2, 1, 1);
RealMatrix expected = createSparseMatrix(new double[][] {
{ 1.0, 2.0, 3.0 }, { 2.0, 1.0, 3.0 }, { 1.0, 2.0, 4.0 } });
assertEquals(expected, m);
m.setSubMatrix(detData2, 0, 0);
expected = createSparseMatrix(new double[][] {
{ 1.0, 3.0, 3.0 }, { 2.0, 4.0, 3.0 }, { 1.0, 2.0, 4.0 } });
assertEquals(expected, m);
m.setSubMatrix(testDataPlus2, 0, 0);
expected = createSparseMatrix(new double[][] {
{ 3.0, 4.0, 5.0 }, { 4.0, 7.0, 5.0 }, { 3.0, 2.0, 10.0 } });
assertEquals(expected, m);
// javadoc example
SparseRealMatrix matrix =
(SparseRealMatrix) createSparseMatrix(new double[][] {
{ 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 9, 0, 1, 2 } });
matrix.setSubMatrix(new double[][] { { 3, 4 }, { 5, 6 } }, 1, 1);
expected = createSparseMatrix(new double[][] {
{ 1, 2, 3, 4 }, { 5, 3, 4, 8 }, { 9, 5, 6, 2 } });
assertEquals(expected, matrix);
// dimension overflow
try {
m.setSubMatrix(testData, 1, 1);
fail("expecting MatrixIndexException");
} catch (MatrixIndexException e) {
// expected
}
// dimension underflow
try {
m.setSubMatrix(testData, -1, 1);
fail("expecting MatrixIndexException");
} catch (MatrixIndexException e) {
// expected
}
try {
m.setSubMatrix(testData, 1, -1);
fail("expecting MatrixIndexException");
} catch (MatrixIndexException e) {
// expected
}
// null
try {
m.setSubMatrix(null, 1, 1);
fail("expecting NullPointerException");
} catch (NullPointerException e) {
// expected
}
try {
new SparseRealMatrix(0, 0);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException e) {
// expected
}
// ragged
try {
m.setSubMatrix(new double[][] { { 1 }, { 2, 3 } }, 0, 0);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException e) {
// expected
}
// empty
try {
m.setSubMatrix(new double[][] { {} }, 0, 0);
fail("expecting IllegalArgumentException");
} catch (IllegalArgumentException e) {
// expected
}
}
// --------------- -----------------Protected methods
/** verifies that two matrices are close (1-norm) */
protected void assertClose(String msg, RealMatrix m, RealMatrix n,
double tolerance) {
assertTrue(msg, m.subtract(n).getNorm() < tolerance);
}
/** verifies that two vectors are close (sup norm) */
protected void assertClose(String msg, double[] m, double[] n,
double tolerance) {
if (m.length != n.length) {
fail("vectors not same length");
}
for (int i = 0; i < m.length; i++) {
assertEquals(msg + " " + i + " elements differ", m[i], n[i],
tolerance);
}
}
private SparseRealMatrix createSparseMatrix(double[][] data) {
SparseRealMatrix matrix = new SparseRealMatrix(data.length, data[0].length);
for (int row = 0; row < data.length; row++) {
for (int col = 0; col < data[row].length; col++) {
matrix.setEntry(row, col, data[row][col]);
}
}
return matrix;
}
}