Greatly improved multiplication speed for sparse matrices

Jira: MATH-248

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@762117 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2009-04-05 16:53:35 +00:00
parent 7dcc9d15c5
commit a2b1aa1695
3 changed files with 77 additions and 3 deletions

View File

@ -150,6 +150,75 @@ public class SparseRealMatrix extends AbstractRealMatrix {
}
/** {@inheritDoc} */
@Override
public RealMatrix multiply(final RealMatrix m)
throws IllegalArgumentException {
try {
return multiply((SparseRealMatrix) m);
} catch (ClassCastException cce) {
// safety check
checkMultiplicationCompatible(m);
final int outCols = m.getColumnDimension();
final DenseRealMatrix out = new DenseRealMatrix(rowDimension, outCols);
for (OpenIntToDoubleHashMap.Iterator iterator = entries.iterator(); iterator.hasNext();) {
iterator.advance();
final double value = iterator.value();
final int key = iterator.key();
final int i = key / columnDimension;
final int k = key % columnDimension;
for (int j = 0; j < outCols; ++j) {
out.addToEntry(i, j, value * m.getEntry(k, j));
}
}
return out;
}
}
/**
* Returns the result of postmultiplying this by m.
*
* @param m matrix to postmultiply by
* @return this * m
* @throws IllegalArgumentException
* if columnDimension(this) != rowDimension(m)
*/
public SparseRealMatrix multiply(SparseRealMatrix m) throws IllegalArgumentException {
// safety check
checkMultiplicationCompatible(m);
final int outCols = m.getColumnDimension();
SparseRealMatrix out = new SparseRealMatrix(rowDimension, outCols);
for (OpenIntToDoubleHashMap.Iterator iterator = entries.iterator(); iterator.hasNext();) {
iterator.advance();
final double value = iterator.value();
final int key = iterator.key();
final int i = key / columnDimension;
final int k = key % columnDimension;
for (int j = 0; j < outCols; ++j) {
final int rightKey = m.computeKey(k, j);
if (m.entries.containsKey(rightKey)) {
final int outKey = out.computeKey(i, j);
final double outValue =
out.entries.get(outKey) + value * m.entries.get(rightKey);
if (outValue == 0.0) {
out.entries.remove(outKey);
} else {
out.entries.put(outKey, outValue);
}
}
}
}
return out;
}
/** {@inheritDoc} */
@Override
public double getEntry(int row, int column) throws MatrixIndexException {

View File

@ -39,6 +39,9 @@ The <action> type attribute can be add,update,fix,remove.
</properties>
<body>
<release version="2.0" date="TBD" description="TBD">
<action dev="luc" type="fix" issue="MATH-248" >
Greatly improved multiplication speed for sparse matrices
</action>
<action dev="luc" type="fix" issue="MATH-253" due-to="Sebb">
Fixed threading issues with MathException and MathRuntimeException
</action>

View File

@ -16,13 +16,13 @@
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
import org.apache.commons.math.linear.decomposition.NonSquareMatrixException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
import org.apache.commons.math.linear.decomposition.NonSquareMatrixException;
/**
* Test cases for the {@link SparseRealMatrix} class.
*
@ -196,6 +196,8 @@ public final class SparseRealMatrixTest extends TestCase {
SparseRealMatrix m2 = createSparseMatrix(testData2);
assertClose("inverse multiply", m.multiply(mInv), identity,
entryTolerance);
assertClose("inverse multiply", m.multiply(new DenseRealMatrix(testDataInv)), identity,
entryTolerance);
assertClose("inverse multiply", mInv.multiply(m), identity,
entryTolerance);
assertClose("identity multiply", m.multiply(identity), m,