Remove NUMERIC_HASHING_THRESHOLD (#10313)

* Make NUMERIC_HASHING_THRESHOLD configurable

Change the default numeric hashing threshold to 1 and make it configurable.

Benchmarks attached to this PR show that binary searches are not faster
than doing a set contains check. The attached flamegraph shows the amount of
time a query spent in the binary search. Given the benchmarks, we can expect
to see roughly a 2x speed up in this part of the query which works out to
~ a 10% faster query in this instance.

* Remove NUMERIC_HASHING_THRESHOLD

* Remove stale docs
This commit is contained in:
Suneet Saldanha 2020-08-25 20:05:39 -07:00 committed by GitHub
parent 91bb27cdf7
commit a9de00d43a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 145 additions and 50 deletions

View File

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark;
import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
/**
 * Benchmarks three strategies for testing membership of a long in a small
 * (16-element) set: linear scan ({@link LongArraySet#contains}), hashing
 * ({@link LongOpenHashSet#contains}), and {@link Arrays#binarySearch} over a
 * sorted copy. Used to justify removing the NUMERIC_HASHING_THRESHOLD
 * crossover in InDimFilter.
 */
@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5)
@Measurement(iterations = 10)
@Fork(value = 1)
public class ContainsBenchmark
{
  // 16 elements: the historical NUMERIC_HASHING_THRESHOLD crossover point.
  private static final long[] LONGS;
  private static final long[] SORTED_LONGS;
  private static final LongOpenHashSet LONG_HASH_SET;
  private static final LongArraySet LONG_ARRAY_SET;

  // Value maximizing work for the linear scan (last element inserted).
  private long worstSearchValue;
  // Value maximizing work for the binary search (deepest probe path).
  private long worstSearchValueBin;

  static {
    LONGS = new long[16];
    for (int i = 0; i < LONGS.length; i++) {
      LONGS[i] = ThreadLocalRandom.current().nextInt(Short.MAX_VALUE);
    }
    LONG_HASH_SET = new LongOpenHashSet(LONGS);
    LONG_ARRAY_SET = new LongArraySet(LONGS);
    SORTED_LONGS = Arrays.copyOf(LONGS, LONGS.length);
    Arrays.sort(SORTED_LONGS);
  }

  @Setup
  public void setUp()
  {
    worstSearchValue = LONGS[LONGS.length - 1];
    // BUG FIX: the previous code picked the middle element of the sorted
    // array (index (length - 1) >>> 1), which Arrays.binarySearch probes
    // *first* — the best case for binary search, not the worst. The largest
    // element forces the deepest probe path in a 16-element array, making
    // this an honest worst case for binarySearch().
    worstSearchValueBin = SORTED_LONGS[SORTED_LONGS.length - 1];
  }

  /** Worst-case linear scan: the sought value is the last one inserted. */
  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  public void linearSearch(Blackhole blackhole)
  {
    boolean found = LONG_ARRAY_SET.contains(worstSearchValue);
    blackhole.consume(found);
  }

  /**
   * Hash-set lookup of the same value used by {@link #binarySearch} so the
   * two alternatives are compared on identical input.
   */
  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  public void hashSetSearch(Blackhole blackhole)
  {
    boolean found = LONG_HASH_SET.contains(worstSearchValueBin);
    blackhole.consume(found);
  }

  /** Worst-case binary search over the sorted copy of the data. */
  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  public void binarySearch(Blackhole blackhole)
  {
    boolean found = Arrays.binarySearch(SORTED_LONGS, worstSearchValueBin) >= 0;
    blackhole.consume(found);
  }
}

View File

@ -66,7 +66,6 @@ import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
@ -78,10 +77,6 @@ import java.util.Set;
public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
{
// determined through benchmark that binary search on long[] is faster than HashSet until ~16 elements
// Hashing threshold is not applied to String for now, String still uses ImmutableSortedSet
public static final int NUMERIC_HASHING_THRESHOLD = 16;
// Values can contain `null` object
private final Set<String> values;
private final String dimension;
@ -113,6 +108,25 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
);
}
/**
 * Creates an "in" filter on {@code dimension} matching any of {@code values},
 * delegating to the four-argument constructor with both trailing arguments
 * null (their semantics are defined by that constructor — not visible here;
 * presumably "no extraction fn / no tuning", TODO confirm).
 *
 * @param dimension the name of the dimension (column) to filter on
 * @param values This collection instance can be reused if possible to avoid copying a big collection.
 * Callers should <b>not</b> modify the collection after it is passed to this constructor.
 */
public InDimFilter(
String dimension,
Set<String> values
)
{
this(
dimension,
values,
null,
null
);
}
/**
* This constructor should be called only in unit tests since accepting a Collection makes copying more likely.
*/
@ -483,14 +497,9 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
}
}
if (longs.size() > NUMERIC_HASHING_THRESHOLD) {
final LongOpenHashSet longHashSet = new LongOpenHashSet(longs);
return longHashSet::contains;
} else {
final long[] longArray = longs.toLongArray();
Arrays.sort(longArray);
return input -> Arrays.binarySearch(longArray, input) >= 0;
}
}
private static DruidFloatPredicate createFloatPredicate(final Set<String> values)
@ -503,16 +512,8 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
}
}
if (floatBits.size() > NUMERIC_HASHING_THRESHOLD) {
final IntOpenHashSet floatBitsHashSet = new IntOpenHashSet(floatBits);
return input -> floatBitsHashSet.contains(Float.floatToIntBits(input));
} else {
final int[] floatBitsArray = floatBits.toIntArray();
Arrays.sort(floatBitsArray);
return input -> Arrays.binarySearch(floatBitsArray, Float.floatToIntBits(input)) >= 0;
}
}
private static DruidDoublePredicate createDoublePredicate(final Set<String> values)
@ -525,16 +526,8 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
}
}
if (doubleBits.size() > NUMERIC_HASHING_THRESHOLD) {
final LongOpenHashSet doubleBitsHashSet = new LongOpenHashSet(doubleBits);
return input -> doubleBitsHashSet.contains(Double.doubleToLongBits(input));
} else {
final long[] doubleBitsArray = doubleBits.toLongArray();
Arrays.sort(doubleBitsArray);
return input -> Arrays.binarySearch(doubleBitsArray, Double.doubleToLongBits(input)) >= 0;
}
}
@VisibleForTesting

View File

@ -236,7 +236,7 @@ public class TopNQueryBuilder
{
final Set<String> filterValues = Sets.newHashSet(values);
filterValues.add(value);
dimFilter = new InDimFilter(dimensionName, filterValues, null, null);
dimFilter = new InDimFilter(dimensionName, filterValues);
return this;
}

View File

@ -438,9 +438,7 @@ public class JoinFilterAnalyzer
for (String correlatedBaseColumn : correlationAnalysis.getBaseColumns()) {
Filter rewrittenFilter = new InDimFilter(
correlatedBaseColumn,
newFilterValues,
null,
null
newFilterValues
).toFilter();
newFilters.add(rewrittenFilter);
}
@ -461,9 +459,7 @@ public class JoinFilterAnalyzer
Filter rewrittenFilter = new InDimFilter(
pushDownVirtualColumn.getOutputName(),
newFilterValues,
null,
null
newFilterValues
).toFilter();
newFilters.add(rewrittenFilter);
}

View File

@ -76,6 +76,7 @@ public class FloatAndDoubleFilteringTest extends BaseFilterTest
private static final String TIMESTAMP_COLUMN = "ts";
private static int EXECUTOR_NUM_THREADS = 16;
private static int EXECUTOR_NUM_TASKS = 2000;
private static final int NUM_FILTER_VALUES = 32;
private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
new TimeAndDimsParseSpec(
@ -200,8 +201,8 @@ public class FloatAndDoubleFilteringTest extends BaseFilterTest
);
// cross the hashing threshold to test hashset implementation, filter on even values
List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
for (int i = 0; i < NUM_FILTER_VALUES; i++) {
infilterValues.add(String.valueOf(i * 2));
}
assertFilterMatches(
@ -377,8 +378,8 @@ public class FloatAndDoubleFilteringTest extends BaseFilterTest
);
// cross the hashing threshold to test hashset implementation, filter on even values
List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
for (int i = 0; i < NUM_FILTER_VALUES; i++) {
infilterValues.add(String.valueOf(i * 2));
}
assertFilterMatchesMultithreaded(

View File

@ -73,6 +73,7 @@ public class LongFilteringTest extends BaseFilterTest
private static final String TIMESTAMP_COLUMN = "ts";
private static int EXECUTOR_NUM_THREADS = 16;
private static int EXECUTOR_NUM_TASKS = 2000;
private static final int NUM_FILTER_VALUES = 32;
private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
new TimeAndDimsParseSpec(
@ -245,8 +246,8 @@ public class LongFilteringTest extends BaseFilterTest
);
// cross the hashing threshold to test hashset implementation, filter on even values
List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
for (int i = 0; i < NUM_FILTER_VALUES; i++) {
infilterValues.add(String.valueOf(i * 2));
}
assertFilterMatches(
@ -393,8 +394,8 @@ public class LongFilteringTest extends BaseFilterTest
);
// cross the hashing threshold to test hashset implementation, filter on even values
List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
for (int i = 0; i < NUM_FILTER_VALUES; i++) {
infilterValues.add(String.valueOf(i * 2));
}
assertFilterMatchesMultithreaded(

View File

@ -67,6 +67,7 @@ import java.util.Map;
public class TimeFilteringTest extends BaseFilterTest
{
private static final String TIMESTAMP_COLUMN = "ts";
private static final int NUM_FILTER_VALUES = 32;
private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
new TimeAndDimsParseSpec(
@ -132,8 +133,8 @@ public class TimeFilteringTest extends BaseFilterTest
);
// cross the hashing threshold to test hashset implementation, filter on even values
List<String> infilterValues = new ArrayList<>(InDimFilter.NUMERIC_HASHING_THRESHOLD * 2);
for (int i = 0; i < InDimFilter.NUMERIC_HASHING_THRESHOLD * 2; i++) {
List<String> infilterValues = new ArrayList<>(NUM_FILTER_VALUES);
for (int i = 0; i < NUM_FILTER_VALUES; i++) {
infilterValues.add(String.valueOf(i * 2));
}
assertFilterMatches(