diff --git a/src/java/org/apache/commons/math/linear/SparseRealVector.java b/src/java/org/apache/commons/math/linear/SparseRealVector.java index a8aecd5a9..8634434a4 100644 --- a/src/java/org/apache/commons/math/linear/SparseRealVector.java +++ b/src/java/org/apache/commons/math/linear/SparseRealVector.java @@ -21,7 +21,7 @@ import org.apache.commons.math.util.OpenIntToDoubleHashMap; import org.apache.commons.math.util.OpenIntToDoubleHashMap.Iterator; /** - * This class implements the {@link RealVector} interface with a {@link OpenIntToDoubleHashMap}. + * This class implements the {@link RealVector} interface with a {@link OpenIntToDoubleHashMap} backing store. * @version $Revision: 728186 $ $Date$ * @since 2.0 */ @@ -230,23 +230,18 @@ public class SparseRealVector implements RealVector { * Optimized method to add two SparseRealVectors * @param v Vector to add with * @return The sum of this with v + * @throws IllegalArgumentException If the dimensions don't match */ - public SparseRealVector add(SparseRealVector v) { + public SparseRealVector add(SparseRealVector v) throws IllegalArgumentException{ checkVectorDimensions(v.getDimension()); - SparseRealVector res = (SparseRealVector) copy(); - Iterator iter = res.getEntries().iterator(); + SparseRealVector res = (SparseRealVector)copy(); + Iterator iter = v.getEntries().iterator(); while (iter.hasNext()) { iter.advance(); int key = iter.key(); - if (v.getEntries().containsKey(key)) { - res.setEntry(key, iter.value() + v.getEntry(key)); - } - } - iter = v.getEntries().iterator(); - while (iter.hasNext()) { - iter.advance(); - int key = iter.key(); - if (!entries.containsKey(key)) { + if (entries.containsKey(key)) { + res.setEntry(key, entries.get(key) + iter.value()); + } else { res.setEntry(key, iter.value()); } } @@ -419,8 +414,9 @@ public class SparseRealVector implements RealVector { * Optimized method to compute distance * @param v The vector to compute distance to * @return The distance from this and v + * @throws IllegalArgumentException If the dimensions don't match */ - public double getDistance(SparseRealVector v) { + public double getDistance(SparseRealVector v) throws IllegalArgumentException { Iterator iter = entries.iterator(); double res = 0; while (iter.hasNext()) { @@ -1013,8 +1009,9 @@ public class SparseRealVector implements RealVector { * Optimized method to compute the outer product * @param v The vector to comput the outer product on * @return The outer product of this and v + * @throws IllegalArgumentException If the dimensions don't match */ - public SparseRealMatrix outerproduct(SparseRealVector v){ + public SparseRealMatrix outerproduct(SparseRealVector v) throws IllegalArgumentException{ checkVectorDimensions(v.getDimension()); SparseRealMatrix res = new SparseRealMatrix(virtualSize, virtualSize); Iterator iter = entries.iterator(); @@ -1109,19 +1106,23 @@ public class SparseRealVector implements RealVector { } } - /** {@inheritDoc} */ - public SparseRealVector subtract(SparseRealVector v) { + /** + * Optimized method to subtract SparseRealVectors + * @param v The vector to subtract from this + * @return The difference of this and v + * @throws IllegalArgumentException If the dimensions don't match + */ + public SparseRealVector subtract(SparseRealVector v) throws IllegalArgumentException{ checkVectorDimensions(v.getDimension()); - SparseRealVector res = new SparseRealVector(this); + SparseRealVector res = (SparseRealVector)copy(); Iterator iter = v.getEntries().iterator(); - OpenIntToDoubleHashMap values = res.getEntries(); while (iter.hasNext()) { iter.advance(); int key = iter.key(); if (entries.containsKey(key)) { - values.put(key, entries.get(key) - iter.value()); + res.setEntry(key, entries.get(key) - iter.value()); } else { - values.put(key, -iter.value()); + res.setEntry(key, -iter.value()); } } return res; @@ -1210,4 +1211,52 @@ public class SparseRealVector implements RealVector { return getData(); } + /* (non-Javadoc) + * @see java.lang.Object#hashCode() + */ + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + long temp; + temp = Double.doubleToLongBits(epsilon); + result = prime * result + (int) (temp ^ (temp >>> 32)); + result = prime * result + virtualSize; + return result; + } + + /* (non-Javadoc) + * @see java.lang.Object#equals(java.lang.Object) + */ + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (!(obj instanceof SparseRealVector)) + return false; + SparseRealVector other = (SparseRealVector) obj; + if (virtualSize != other.virtualSize) + return false; + if (Double.doubleToLongBits(epsilon) != Double + .doubleToLongBits(other.epsilon)) + return false; + Iterator iter = entries.iterator(); + while(iter.hasNext()){ + iter.advance(); + double test = iter.value() - other.getEntry(iter.key()); + if(Math.abs(test) > epsilon) + return false; + } + iter = other.getEntries().iterator(); + while(iter.hasNext()){ + iter.advance(); + double test = iter.value() - getEntry(iter.key()); + if(!isZero(test)) + return false; + } + return true; + } + } diff --git a/src/java/org/apache/commons/math/util/OpenIntToDoubleHashMap.java b/src/java/org/apache/commons/math/util/OpenIntToDoubleHashMap.java index 9fdf27c9e..af6a6cf16 100644 --- a/src/java/org/apache/commons/math/util/OpenIntToDoubleHashMap.java +++ b/src/java/org/apache/commons/math/util/OpenIntToDoubleHashMap.java @@ -20,7 +20,6 @@ package org.apache.commons.math.util; import java.io.IOException; import java.io.ObjectInputStream; import java.io.Serializable; -import java.util.Arrays; import java.util.ConcurrentModificationException; import java.util.NoSuchElementException; @@ -476,6 +475,7 @@ public class OpenIntToDoubleHashMap implements Serializable { return h ^ (h >>> 7) ^ (h >>> 4); } + /** Iterator class for the map. */ public class Iterator { @@ -595,5 +595,5 @@ public class OpenIntToDoubleHashMap implements Serializable { count = 0; } - + } diff --git a/src/test/org/apache/commons/math/linear/SparseRealVectorTest.java b/src/test/org/apache/commons/math/linear/SparseRealVectorTest.java index fee502217..3b8c6c466 100644 --- a/src/test/org/apache/commons/math/linear/SparseRealVectorTest.java +++ b/src/test/org/apache/commons/math/linear/SparseRealVectorTest.java @@ -1124,12 +1124,11 @@ public class SparseRealVectorTest extends TestCase { v.setEntry(1, 1); assertTrue(v.isInfinite()); - //TODO: backing store doesn't implement equals //TODO: differeciate from resetting to zero - //v.setEntry(0, 0); - //assertEquals(v, new SparseRealVector(new double[] { 0, 1, 2 })); - //assertNotSame(v, new SparseRealVector(new double[] { 0, 1, 2 + Math.ulp(2)})); - //assertNotSame(v, new SparseRealVector(new double[] { 0, 1, 2, 3 })); + v.setEntry(0, 0); + assertEquals(v, new SparseRealVector(new double[] { 0, 1, 2 })); + assertNotSame(v, new SparseRealVector(new double[] { 0, 1, 2 + Math.ulp(2)})); + assertNotSame(v, new SparseRealVector(new double[] { 0, 1, 2, 3 })); //assertEquals(new SparseRealVector(new double[] { Double.NaN, 1, 2 }).hashCode(), // new SparseRealVector(new double[] { 0, Double.NaN, 2 }).hashCode());