diff --git a/src/java/org/apache/commons/math/linear/SparseRealMatrix.java b/src/java/org/apache/commons/math/linear/SparseRealMatrix.java index e5c02625c..33f3b7e7d 100644 --- a/src/java/org/apache/commons/math/linear/SparseRealMatrix.java +++ b/src/java/org/apache/commons/math/linear/SparseRealMatrix.java @@ -150,6 +150,75 @@ public class SparseRealMatrix extends AbstractRealMatrix { } + /** {@inheritDoc} */ + @Override + public RealMatrix multiply(final RealMatrix m) + throws IllegalArgumentException { + try { + return multiply((SparseRealMatrix) m); + } catch (ClassCastException cce) { + + // safety check + checkMultiplicationCompatible(m); + + final int outCols = m.getColumnDimension(); + final DenseRealMatrix out = new DenseRealMatrix(rowDimension, outCols); + for (OpenIntToDoubleHashMap.Iterator iterator = entries.iterator(); iterator.hasNext();) { + iterator.advance(); + final double value = iterator.value(); + final int key = iterator.key(); + final int i = key / columnDimension; + final int k = key % columnDimension; + for (int j = 0; j < outCols; ++j) { + out.addToEntry(i, j, value * m.getEntry(k, j)); + } + } + + return out; + + } + } + + /** + * Returns the result of postmultiplying this by m. + * + * @param m matrix to postmultiply by + * @return this * m + * @throws IllegalArgumentException + * if columnDimension(this) != rowDimension(m) + */ + public SparseRealMatrix multiply(SparseRealMatrix m) throws IllegalArgumentException { + + // safety check + checkMultiplicationCompatible(m); + + final int outCols = m.getColumnDimension(); + SparseRealMatrix out = new SparseRealMatrix(rowDimension, outCols); + for (OpenIntToDoubleHashMap.Iterator iterator = entries.iterator(); iterator.hasNext();) { + iterator.advance(); + final double value = iterator.value(); + final int key = iterator.key(); + final int i = key / columnDimension; + final int k = key % columnDimension; + for (int j = 0; j < outCols; ++j) { + final int rightKey = m.computeKey(k, j); + if (m.entries.containsKey(rightKey)) { + final int outKey = out.computeKey(i, j); + final double outValue = + out.entries.get(outKey) + value * m.entries.get(rightKey); + if (outValue == 0.0) { + out.entries.remove(outKey); + } else { + out.entries.put(outKey, outValue); + } + } + } + } + + return out; + + } + /** {@inheritDoc} */ @Override public double getEntry(int row, int column) throws MatrixIndexException { diff --git a/src/site/xdoc/changes.xml b/src/site/xdoc/changes.xml index d6413e8c0..8f1486b76 100644 --- a/src/site/xdoc/changes.xml +++ b/src/site/xdoc/changes.xml @@ -39,6 +39,9 @@ The type attribute can be add,update,fix,remove. + + Greatly improved multiplication speed for sparse matrices + Fixed threading issues with MathException and MathRuntimeException diff --git a/src/test/org/apache/commons/math/linear/SparseRealMatrixTest.java b/src/test/org/apache/commons/math/linear/SparseRealMatrixTest.java index d83097726..21eedc98f 100644 --- a/src/test/org/apache/commons/math/linear/SparseRealMatrixTest.java +++ b/src/test/org/apache/commons/math/linear/SparseRealMatrixTest.java @@ -16,13 +16,13 @@ */ package org.apache.commons.math.linear; -import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; -import org.apache.commons.math.linear.decomposition.NonSquareMatrixException; - import junit.framework.Test; import junit.framework.TestCase; import junit.framework.TestSuite; +import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; +import org.apache.commons.math.linear.decomposition.NonSquareMatrixException; + /** * Test cases for the {@link SparseRealMatrix} class. * @@ -196,6 +196,8 @@ public final class SparseRealMatrixTest extends TestCase { SparseRealMatrix m2 = createSparseMatrix(testData2); assertClose("inverse multiply", m.multiply(mInv), identity, entryTolerance); + assertClose("inverse multiply", m.multiply(new DenseRealMatrix(testDataInv)), identity, + entryTolerance); assertClose("inverse multiply", mInv.multiply(m), identity, entryTolerance); assertClose("identity multiply", m.multiply(identity), m,