diff --git a/src/main/java/org/apache/commons/lang3/EnumUtils.java b/src/main/java/org/apache/commons/lang3/EnumUtils.java index 774b31659..ca7953fb0 100644 --- a/src/main/java/org/apache/commons/lang3/EnumUtils.java +++ b/src/main/java/org/apache/commons/lang3/EnumUtils.java @@ -123,58 +123,86 @@ public static > E getEnum(Class enumClass, String enumName) *

Do not use this method if you have more than 64 values in your Enum, as this * would create a value greater than a long can hold.

* - * @param enumClass the class of the enum we are working with, not null - * @param set the set of enum values we want to convert + * @param enumClass the class of the enum we are working with, not {@code null} + * @param values the values we want to convert, not {@code null} * @param the type of the enumeration * @return a long whose binary value represents the given set of enum values. + * @throws NullPointerException if {@code enumClass} or {@code values} is {@code null} + * @throws IllegalArgumentException if {@code enumClass} is not an enum class or has more than 64 values + * @since 3.0.1 */ - public static > long generateBitVector(Class enumClass, EnumSet set) { - if (enumClass == null) { - throw new IllegalArgumentException("EnumClass must be defined."); - } - final E[] constants = enumClass.getEnumConstants(); - if (constants != null && constants.length > 64) { - throw new IllegalArgumentException("EnumClass is too big to be stored in a 64-bit value."); - } + public static > long generateBitVector(Class enumClass, Iterable values) { + checkBitVectorable(enumClass); + Validate.notNull(values); long total = 0; - if (set != null) { - if (constants != null && constants.length > 0) { - for (E constant : constants) { - if (set.contains(constant)) { - total += Math.pow(2, constant.ordinal()); - } - } - } + for (E constant : values) { + total |= (1 << constant.ordinal()); } return total; } + /** + *

Creates a long bit vector representation of the given array of Enum values.

+ * + *

This generates a value that is usable by {@link EnumUtils#processBitVector}.

+ * + *

Do not use this method if you have more than 64 values in your Enum, as this + * would create a value greater than a long can hold.

+ * + * @param enumClass the class of the enum we are working with, not {@code null} + * @param values the values we want to convert, not {@code null} + * @param the type of the enumeration + * @return a long whose binary value represents the given set of enum values. + * @throws NullPointerException if {@code enumClass} or {@code values} is {@code null} + * @throws IllegalArgumentException if {@code enumClass} is not an enum class or has more than 64 values + * @since 3.0.1 + */ + public static > long generateBitVector(Class enumClass, E... values) { + Validate.noNullElements(values); + return generateBitVector(enumClass, Arrays. asList(values)); + } + /** *

Convert a long value created by {@link EnumUtils#generateBitVector} into the set of * enum values that it represents.

* *

If you store this value, beware any changes to the enum that would affect ordinal values.

- * @param enumClass the class of the enum we are working with, not null + * @param enumClass the class of the enum we are working with, not {@code null} * @param value the long value representation of a set of enum values * @param the type of the enumeration * @return a set of enum values + * @throws NullPointerException if {@code enumClass} is {@code null} + * @throws IllegalArgumentException if {@code enumClass} is not an enum class or has more than 64 values + * @since 3.0.1 */ public static > EnumSet processBitVector(Class enumClass, long value) { - if (enumClass == null) { - throw new IllegalArgumentException("EnumClass must be defined."); - } - final E[] constants = enumClass.getEnumConstants(); - if (constants != null && constants.length > 64) { - throw new IllegalArgumentException("EnumClass is too big to be stored in a 64-bit value."); - } - final EnumSet results = EnumSet.noneOf(enumClass); - if (constants != null && constants.length > 0) { - for (E constant : constants) { - if ((value & (1 << constant.ordinal())) != 0) { - results.add(constant); - } + final E[] constants = checkBitVectorable(enumClass).getEnumConstants(); + final EnumSet results = EnumSet.noneOf(enumClass); + for (E constant : constants) { + if ((value & (1 << constant.ordinal())) != 0) { + results.add(constant); } } return results; } + + /** + * Validate that {@code enumClass} is compatible with representation in a {@code long}. + * @param the type of the enumeration + * @param enumClass to check + * @return {@code enumClass} + * @throws NullPointerException if {@code enumClass} is {@code null} + * @throws IllegalArgumentException if {@code enumClass} is not an enum class or has more than 64 values + * @since 3.0.1 + */ + private static > Class checkBitVectorable(Class enumClass) { + Validate.notNull(enumClass, "EnumClass must be defined."); + + final E[] constants = enumClass.getEnumConstants(); + Validate.isTrue(constants != null, "%s does not seem to be an Enum type", enumClass); + Validate.isTrue(constants.length <= Long.SIZE, "Cannot store %s %s values in %s bits", constants.length, + enumClass.getSimpleName(), Long.SIZE); + + return enumClass; + } } diff --git a/src/test/java/org/apache/commons/lang3/EnumUtilsTest.java b/src/test/java/org/apache/commons/lang3/EnumUtilsTest.java index f3e7e322d..de38abb60 100644 --- a/src/test/java/org/apache/commons/lang3/EnumUtilsTest.java +++ b/src/test/java/org/apache/commons/lang3/EnumUtilsTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.*; +import java.util.ArrayList; import java.util.EnumSet; import java.util.List; import java.util.Map; @@ -89,19 +90,56 @@ public void test_getEnum_nullClass() { EnumUtils.getEnum((Class) null, "PURPLE"); } - @Test(expected=IllegalArgumentException.class) + @Test(expected=NullPointerException.class) public void test_generateBitVector_nullClass() { EnumUtils.generateBitVector(null, EnumSet.of(Traffic.RED)); } + @Test(expected=NullPointerException.class) + public void test_generateBitVector_nullIterable() { + EnumUtils.generateBitVector(null, (Iterable) null); + } + + @Test(expected=NullPointerException.class) + public void test_generateBitVector_nullClassWithArray() { + EnumUtils.generateBitVector(null, Traffic.RED); + } + + @Test(expected=NullPointerException.class) + public void test_generateBitVector_nullArray() { + EnumUtils.generateBitVector(null, (Traffic[]) null); + } + @Test(expected=IllegalArgumentException.class) public void test_generateBitVector_longClass() { EnumUtils.generateBitVector(TooMany.class, EnumSet.of(TooMany.A1)); } + @Test(expected=IllegalArgumentException.class) + public void test_generateBitVector_longClassWithArray() { + EnumUtils.generateBitVector(TooMany.class, TooMany.A1); + } + + @SuppressWarnings("unchecked") + @Test(expected=IllegalArgumentException.class) + public void test_generateBitVector_nonEnumClass() { + @SuppressWarnings("rawtypes") + Class rawType = Object.class; + @SuppressWarnings("rawtypes") + List rawList = new ArrayList(); + EnumUtils.generateBitVector(rawType, rawList); + } + + @SuppressWarnings("unchecked") + @Test(expected=IllegalArgumentException.class) + public void test_generateBitVector_nonEnumClassWithArray() { + @SuppressWarnings("rawtypes") + Class rawType = Object.class; + EnumUtils.generateBitVector(rawType); + } + @Test public void test_generateBitVector() { - assertEquals(0L, EnumUtils.generateBitVector(Traffic.class, null)); assertEquals(0L, EnumUtils.generateBitVector(Traffic.class, EnumSet.noneOf(Traffic.class))); assertEquals(1L, EnumUtils.generateBitVector(Traffic.class, EnumSet.of(Traffic.RED))); assertEquals(2L, EnumUtils.generateBitVector(Traffic.class, EnumSet.of(Traffic.AMBER))); @@ -112,7 +150,21 @@ public void test_generateBitVector() { assertEquals(7L, EnumUtils.generateBitVector(Traffic.class, EnumSet.of(Traffic.RED, Traffic.AMBER, Traffic.GREEN))); } - @Test(expected=IllegalArgumentException.class) + @Test + public void test_generateBitVectorFromArray() { + assertEquals(0L, EnumUtils.generateBitVector(Traffic.class)); + assertEquals(1L, EnumUtils.generateBitVector(Traffic.class, Traffic.RED)); + assertEquals(2L, EnumUtils.generateBitVector(Traffic.class, Traffic.AMBER)); + assertEquals(4L, EnumUtils.generateBitVector(Traffic.class, Traffic.GREEN)); + assertEquals(3L, EnumUtils.generateBitVector(Traffic.class, Traffic.RED, Traffic.AMBER)); + assertEquals(5L, EnumUtils.generateBitVector(Traffic.class, Traffic.RED, Traffic.GREEN)); + assertEquals(6L, EnumUtils.generateBitVector(Traffic.class, Traffic.AMBER, Traffic.GREEN)); + assertEquals(7L, EnumUtils.generateBitVector(Traffic.class, Traffic.RED, Traffic.AMBER, Traffic.GREEN)); + //gracefully handles duplicates: + assertEquals(7L, EnumUtils.generateBitVector(Traffic.class, Traffic.RED, Traffic.AMBER, Traffic.GREEN, Traffic.GREEN)); + } + + @Test(expected=NullPointerException.class) public void test_processBitVector_nullClass() { final Class empty = null; EnumUtils.processBitVector(empty, 0L);