diff --git a/src/main/java/org/apache/commons/collections/CollectionUtils.java b/src/main/java/org/apache/commons/collections/CollectionUtils.java index da01b448d..157859170 100644 --- a/src/main/java/org/apache/commons/collections/CollectionUtils.java +++ b/src/main/java/org/apache/commons/collections/CollectionUtils.java @@ -28,6 +28,7 @@ import java.util.ListIterator; import java.util.Map; import java.util.Set; +import org.apache.commons.collections.bag.HashBag; import org.apache.commons.collections.collection.PredicatedCollection; import org.apache.commons.collections.collection.SynchronizedCollection; import org.apache.commons.collections.collection.TransformedCollection; @@ -245,6 +246,12 @@ public class CollectionUtils { * Returns a new {@link Collection} containing a minus a subset of * b. Only the elements of b that satisfy the predicate * condition, p are subtracted from a. + * + *

The cardinality of each element e in the returned {@link Collection} + * that satisfies the predicate condition will be the cardinality of e in a + * minus the cardinality of e in b, or zero, whichever is greater.

+ *

The cardinality of each element e in the returned {@link Collection} that does not + * satisfy the predicate condition will be equal to the cardinality of e in a.

* * @param a the collection to subtract from, must not be null * @param b the collection to subtract, must not be null @@ -256,12 +263,19 @@ public class CollectionUtils { * @since 4.0 * @see Collection#removeAll */ - public static Collection subtract(final Iterable a, final Iterable b, final Predicate p) { - ArrayList list = new ArrayList(); - addAll(list, a); + public static Collection subtract(final Iterable a, + final Iterable b, + final Predicate p) { + final ArrayList list = new ArrayList(); + final HashBag bag = new HashBag(); for (O element : b) { if (p.evaluate(element)) { - list.remove(element); + bag.add(element); + } + } + for (O element : a) { + if (!bag.remove(element, 1)) { + list.add(element); } } return list;