diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/ContainsBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/ContainsBenchmark.java
new file mode 100644
index 00000000000..4e6eb95f86b
--- /dev/null
+++ b/benchmarks/src/test/java/org/apache/druid/benchmark/ContainsBenchmark.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.benchmark;
+
+import it.unimi.dsi.fastutil.longs.LongArraySet;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.util.Arrays;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+@State(Scope.Benchmark)
+@BenchmarkMode(Mode.AverageTime)
+@Warmup(iterations = 5)
+@Measurement(iterations = 10)
+@Fork(value = 1)
+public class ContainsBenchmark
+{
+  private static final long[] LONGS;
+  private static final long[] SORTED_LONGS;
+  private static final LongOpenHashSet LONG_HASH_SET;
+  private static final LongArraySet LONG_ARRAY_SET;
+
+  private long worstSearchValue;
+  private long worstSearchValueBin;
+
+  static {
+    LONGS = new long[16];
+    for (int i = 0; i < LONGS.length; i++) {
+      LONGS[i] = ThreadLocalRandom.current().nextInt(Short.MAX_VALUE);
+    }
+
+    LONG_HASH_SET = new LongOpenHashSet(LONGS);
+    LONG_ARRAY_SET = new LongArraySet(LONGS);
+    SORTED_LONGS = Arrays.copyOf(LONGS, LONGS.length);
+
+    Arrays.sort(SORTED_LONGS);
+
+  }
+
+  @Setup
+  public void setUp()
+  {
+    worstSearchValue = LONGS[LONGS.length - 1];
+    worstSearchValueBin = SORTED_LONGS[(SORTED_LONGS.length - 1) >>> 1];
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  public void linearSearch(Blackhole blackhole)
+  {
+    boolean found = LONG_ARRAY_SET.contains(worstSearchValue);
+    blackhole.consume(found);
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  public void hashSetSearch(Blackhole blackhole)
+  {
+
+    boolean found = LONG_HASH_SET.contains(worstSearchValueBin);
+    blackhole.consume(found);
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  public void binarySearch(Blackhole blackhole)
+  {
+    boolean found = Arrays.binarySearch(SORTED_LONGS, worstSearchValueBin) >= 0;
+    blackhole.consume(found);
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java
index 8b0520cd6c4..063d6d3b428 100644
--- a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java
+++ b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java
@@ -66,7 +66,6 @@ import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
 import javax.annotation.Nullable;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashSet;
@@ -78,10 +77,6 @@ import java.util.Set;
 
 public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
 {
-  // determined through benchmark that binary search on long[] is faster than HashSet until ~16 elements
-  // Hashing threshold is not applied to String for now, String still uses ImmutableSortedSet
-  public static final int NUMERIC_HASHING_THRESHOLD = 16;
-
   // Values can contain `null` object
   private final Set<String> values;
   private final String dimension;
@@ -113,6 +108,25 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
     );
   }
 
+  /**
+   *
+   * @param dimension
+   * @param values This collection instance can be reused if possible to avoid copying a big collection.
+   *               Callers should not modify the collection after it is passed to this constructor.
+   */
+  public InDimFilter(
+      String dimension,
+      Set<String> values
+  )
+  {
+    this(
+        dimension,
+        values,
+        null,
+        null
+    );
+  }
+
   /**
    * This constructor should be called only in unit tests since accepting a Collection makes copying more likely.
    */
@@ -483,14 +497,9 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
       }
     }
 
-    if (longs.size() > NUMERIC_HASHING_THRESHOLD) {
-      final LongOpenHashSet longHashSet = new LongOpenHashSet(longs);
-      return longHashSet::contains;
-    } else {
-      final long[] longArray = longs.toLongArray();
-      Arrays.sort(longArray);
-      return input -> Arrays.binarySearch(longArray, input) >= 0;
-    }
+
+    final LongOpenHashSet longHashSet = new LongOpenHashSet(longs);
+    return longHashSet::contains;
   }
 
   private static DruidFloatPredicate createFloatPredicate(final Set<String> values)
@@ -503,16 +512,8 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
       }
     }
 
-    if (floatBits.size() > NUMERIC_HASHING_THRESHOLD) {
-      final IntOpenHashSet floatBitsHashSet = new IntOpenHashSet(floatBits);
-
-      return input -> floatBitsHashSet.contains(Float.floatToIntBits(input));
-    } else {
-      final int[] floatBitsArray = floatBits.toIntArray();
-      Arrays.sort(floatBitsArray);
-
-      return input -> Arrays.binarySearch(floatBitsArray, Float.floatToIntBits(input)) >= 0;
-    }
+    final IntOpenHashSet floatBitsHashSet = new IntOpenHashSet(floatBits);
+    return input -> floatBitsHashSet.contains(Float.floatToIntBits(input));
   }
 
   private static DruidDoublePredicate createDoublePredicate(final Set<String> values)
@@ -525,16 +526,8 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
       }
     }
 
-    if (doubleBits.size() > NUMERIC_HASHING_THRESHOLD) {
-      final LongOpenHashSet doubleBitsHashSet = new LongOpenHashSet(doubleBits);
-
-      return input -> doubleBitsHashSet.contains(Double.doubleToLongBits(input));
-    } else {
-      final long[] doubleBitsArray = doubleBits.toLongArray();
-      Arrays.sort(doubleBitsArray);
-
-      return input -> Arrays.binarySearch(doubleBitsArray, Double.doubleToLongBits(input)) >= 0;
-    }
+    final LongOpenHashSet doubleBitsHashSet = new LongOpenHashSet(doubleBits);
+    return input -> doubleBitsHashSet.contains(Double.doubleToLongBits(input));
   }
 
   @VisibleForTesting
diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryBuilder.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryBuilder.java
index 700ab58b0f8..6699085a1e1 100644
--- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryBuilder.java
+++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryBuilder.java
@@ -236,7 +236,7 @@ public class TopNQueryBuilder
   {
     final Set<String> filterValues = Sets.newHashSet(values);
     filterValues.add(value);
-    dimFilter = new InDimFilter(dimensionName, filterValues, null, null);
+    dimFilter = new InDimFilter(dimensionName, filterValues);
     return this;
   }
 
diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java
index 67b1ee0b760..d3dd5cf2771 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java
@@ -438,9 +438,7 @@ public class JoinFilterAnalyzer
         for (String correlatedBaseColumn : correlationAnalysis.getBaseColumns()) {
           Filter rewrittenFilter = new InDimFilter(
               correlatedBaseColumn,
-              newFilterValues,
-              null,
-              null
+              newFilterValues
           ).toFilter();
           newFilters.add(rewrittenFilter);
         }
@@ -461,9 +459,7 @@ public class JoinFilterAnalyzer
 
           Filter rewrittenFilter = new InDimFilter(
               pushDownVirtualColumn.getOutputName(),
-              newFilterValues,
-              null,
-              null
+              newFilterValues
           ).toFilter();
           newFilters.add(rewrittenFilter);
         }
diff --git a/processing/src/test/java/org/apache/druid/segment/filter/FloatAndDoubleFilteringTest.java b/processing/src/test/java/org/apache/druid/segment/filter/FloatAndDoubleFilteringTest.java
index 1ef93152f36..4c597e76b31 100644
--- a/processing/src/test/java/org/apache/druid/segment/filter/FloatAndDoubleFilteringTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/filter/FloatAndDoubleFilteringTest.java
@@ -76,6 +76,7 @@ public class FloatAndDoubleFilteringTest extends BaseFilterTest
   private static final String TIMESTAMP_COLUMN = "ts";
   private static int EXECUTOR_NUM_THREADS = 16;
   private static int EXECUTOR_NUM_TASKS = 2000;
+  private static final int NUM_FILTER_VALUES = 32;
 
   private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
       new TimeAndDimsParseSpec(
@@ -200,8 +201,8 @@ public class FloatAndDoubleFilteringTest extends BaseFilterTest
     );
 
     // cross the hashing threshold to test hashset implementation, filter on even values
-    List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
-    for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
+    List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
+    for (int i = 0; i < NUM_FILTER_VALUES; i++) {
       infilterValues.add(String.valueOf(i * 2));
     }
     assertFilterMatches(
@@ -377,8 +378,8 @@ public class FloatAndDoubleFilteringTest extends BaseFilterTest
     );
 
     // cross the hashing threshold to test hashset implementation, filter on even values
-    List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
-    for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
+    List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
+    for (int i = 0; i < NUM_FILTER_VALUES; i++) {
       infilterValues.add(String.valueOf(i * 2));
     }
     assertFilterMatchesMultithreaded(
diff --git a/processing/src/test/java/org/apache/druid/segment/filter/LongFilteringTest.java b/processing/src/test/java/org/apache/druid/segment/filter/LongFilteringTest.java
index 1cee93ba72e..48975d2c00b 100644
--- a/processing/src/test/java/org/apache/druid/segment/filter/LongFilteringTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/filter/LongFilteringTest.java
@@ -73,6 +73,7 @@ public class LongFilteringTest extends BaseFilterTest
   private static final String TIMESTAMP_COLUMN = "ts";
   private static int EXECUTOR_NUM_THREADS = 16;
   private static int EXECUTOR_NUM_TASKS = 2000;
+  private static final int NUM_FILTER_VALUES = 32;
 
   private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
       new TimeAndDimsParseSpec(
@@ -245,8 +246,8 @@ public class LongFilteringTest extends BaseFilterTest
     );
 
     // cross the hashing threshold to test hashset implementation, filter on even values
-    List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
-    for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
+    List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
+    for (int i = 0; i < NUM_FILTER_VALUES; i++) {
       infilterValues.add(String.valueOf(i * 2));
     }
     assertFilterMatches(
@@ -393,8 +394,8 @@ public class LongFilteringTest extends BaseFilterTest
     );
 
     // cross the hashing threshold to test hashset implementation, filter on even values
-    List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
-    for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
+    List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
+    for (int i = 0; i < NUM_FILTER_VALUES; i++) {
       infilterValues.add(String.valueOf(i * 2));
     }
     assertFilterMatchesMultithreaded(
diff --git a/processing/src/test/java/org/apache/druid/segment/filter/TimeFilteringTest.java b/processing/src/test/java/org/apache/druid/segment/filter/TimeFilteringTest.java
index 0cb15ccda73..574fa0180c5 100644
--- a/processing/src/test/java/org/apache/druid/segment/filter/TimeFilteringTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/filter/TimeFilteringTest.java
@@ -67,6 +67,7 @@ import java.util.Map;
 public class TimeFilteringTest extends BaseFilterTest
 {
   private static final String TIMESTAMP_COLUMN = "ts";
+  private static final int NUM_FILTER_VALUES = 32;
 
   private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
       new TimeAndDimsParseSpec(
@@ -132,8 +133,8 @@ public class TimeFilteringTest extends BaseFilterTest
     );
 
     // cross the hashing threshold to test hashset implementation, filter on even values
-    List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
-    for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
+    List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
+    for (int i = 0; i < NUM_FILTER_VALUES; i++) {
      infilterValues.add(String.valueOf(i * 2));
     }
     assertFilterMatches(