Speed up docvalues set query by making use of sortedness (#12128)

LongHashSet is used for the set of numbers, but it has some issues:
* tries too hard to extend AbstractSet, mostly for testing
* causes traps with boxing if you aren't careful
* complex hashcode/equals

Practically, we should take advantage of the fact that numbers come in sorted
order for multivalued fields, just like range queries do. So we use
min/max to our advantage, including termination of docvalues iteration.

Actually it is generally a win to just check min/max even in the single-valued
case: these constant time comparisons are cheap and can avoid hashing,
etc.

In the worst case, if all of your query Sets contain both the minimum and maximum
possible values, then it won't help, but it doesn't hurt either.
This commit is contained in:
Robert Muir 2023-02-06 12:14:02 -05:00 committed by GitHub
parent a6bceb7cf0
commit 10d9c7440b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 76 deletions

View File

@ -16,15 +16,16 @@
*/ */
package org.apache.lucene.document; package org.apache.lucene.document;
import java.util.AbstractSet;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.HashSet;
import java.util.NoSuchElementException; import java.util.Objects;
import java.util.Set;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.packed.PackedInts; import org.apache.lucene.util.packed.PackedInts;
final class LongHashSet extends AbstractSet<Long> implements Accountable { /** Set of longs, optimized for docvalues usage */
final class LongHashSet implements Accountable {
private static final long BASE_RAM_BYTES = private static final long BASE_RAM_BYTES =
RamUsageEstimator.shallowSizeOfInstance(LongHashSet.class); RamUsageEstimator.shallowSizeOfInstance(LongHashSet.class);
@ -34,8 +35,12 @@ final class LongHashSet extends AbstractSet<Long> implements Accountable {
final int mask; final int mask;
final boolean hasMissingValue; final boolean hasMissingValue;
final int size; final int size;
final int hashCode; /** minimum value in the set, or Long.MAX_VALUE for an empty set */
final long minValue;
/** maximum value in the set, or Long.MIN_VALUE for an empty set */
final long maxValue;
/** Construct a set. Values must be in sorted order. */
LongHashSet(long[] values) { LongHashSet(long[] values) {
int tableSize = Math.toIntExact(values.length * 3L / 2); int tableSize = Math.toIntExact(values.length * 3L / 2);
tableSize = 1 << PackedInts.bitsRequired(tableSize); // make it a power of 2 tableSize = 1 << PackedInts.bitsRequired(tableSize); // make it a power of 2
@ -45,19 +50,21 @@ final class LongHashSet extends AbstractSet<Long> implements Accountable {
mask = tableSize - 1; mask = tableSize - 1;
boolean hasMissingValue = false; boolean hasMissingValue = false;
int size = 0; int size = 0;
int hashCode = 0; long previousValue = Long.MIN_VALUE; // for assert
for (long value : values) { for (long value : values) {
if (value == MISSING || add(value)) { if (value == MISSING || add(value)) {
if (value == MISSING) { if (value == MISSING) {
hasMissingValue = true; hasMissingValue = true;
} }
++size; ++size;
hashCode += Long.hashCode(value);
} }
assert value >= previousValue : "values must be provided in sorted order";
previousValue = value;
} }
this.hasMissingValue = hasMissingValue; this.hasMissingValue = hasMissingValue;
this.size = size; this.size = size;
this.hashCode = hashCode; this.minValue = values.length == 0 ? Long.MAX_VALUE : values[0];
this.maxValue = values.length == 0 ? Long.MIN_VALUE : values[values.length - 1];
} }
private boolean add(long l) { private boolean add(long l) {
@ -74,6 +81,12 @@ final class LongHashSet extends AbstractSet<Long> implements Accountable {
} }
} }
/**
* check for membership in the set.
*
* <p>You should use {@link #minValue} and {@link #maxValue} to guide/terminate iteration before
* calling this.
*/
boolean contains(long l) { boolean contains(long l) {
if (l == MISSING) { if (l == MISSING) {
return hasMissingValue; return hasMissingValue;
@ -88,33 +101,49 @@ final class LongHashSet extends AbstractSet<Long> implements Accountable {
} }
} }
@Override
public int size() {
return size;
}
@Override @Override
public int hashCode() { public int hashCode() {
return hashCode; return Objects.hash(size, minValue, maxValue, mask, hasMissingValue, Arrays.hashCode(table));
} }
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
if (obj != null && obj.getClass() == LongHashSet.class) { if (obj != null && obj instanceof LongHashSet) {
LongHashSet that = (LongHashSet) obj; LongHashSet that = (LongHashSet) obj;
if (hashCode != that.hashCode return size == that.size
|| size != that.size && minValue == that.minValue
|| hasMissingValue != that.hasMissingValue) { && maxValue == that.maxValue
return false; && mask == that.mask
} && hasMissingValue == that.hasMissingValue
for (long v : table) { && Arrays.equals(table, that.table);
if (v != MISSING && that.contains(v) == false) {
return false;
}
}
return true;
} }
return super.equals(obj); return false;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder("[");
boolean seenValue = false;
if (hasMissingValue) {
sb.append(MISSING);
seenValue = true;
}
for (long v : table) {
if (v != MISSING) {
if (seenValue) {
sb.append(", ");
}
sb.append(v);
seenValue = true;
}
}
sb.append("]");
return sb.toString();
}
/** number of elements in the set */
int size() {
return size;
} }
@Override @Override
@ -122,41 +151,17 @@ final class LongHashSet extends AbstractSet<Long> implements Accountable {
return BASE_RAM_BYTES + RamUsageEstimator.sizeOfObject(table); return BASE_RAM_BYTES + RamUsageEstimator.sizeOfObject(table);
} }
@Override // for testing only
public boolean contains(Object o) { Set<Long> toSet() {
return o instanceof Long && contains(((Long) o).longValue()); Set<Long> set = new HashSet<>();
} if (hasMissingValue) {
set.add(MISSING);
@Override }
public Iterator<Long> iterator() { for (long v : table) {
return new Iterator<Long>() { if (v != MISSING) {
set.add(v);
private boolean hasNext = hasMissingValue;
private int i = -1;
private long value = MISSING;
@Override
public boolean hasNext() {
if (hasNext) {
return true;
}
while (++i < table.length) {
value = table[i];
if (value != MISSING) {
return hasNext = true;
}
}
return false;
} }
}
@Override return set;
public Long next() {
if (hasNext() == false) {
throw new NoSuchElementException();
}
hasNext = false;
return value;
}
};
} }
} }

View File

@ -17,6 +17,7 @@
package org.apache.lucene.document; package org.apache.lucene.document;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
@ -45,6 +46,7 @@ final class SortedNumericDocValuesSetQuery extends Query implements Accountable
SortedNumericDocValuesSetQuery(String field, long[] numbers) { SortedNumericDocValuesSetQuery(String field, long[] numbers) {
this.field = Objects.requireNonNull(field); this.field = Objects.requireNonNull(field);
Arrays.sort(numbers);
this.numbers = new LongHashSet(numbers); this.numbers = new LongHashSet(numbers);
} }
@ -112,12 +114,15 @@ final class SortedNumericDocValuesSetQuery extends Query implements Accountable
new TwoPhaseIterator(singleton) { new TwoPhaseIterator(singleton) {
@Override @Override
public boolean matches() throws IOException { public boolean matches() throws IOException {
return numbers.contains(singleton.longValue()); long value = singleton.longValue();
return value >= numbers.minValue
&& value <= numbers.maxValue
&& numbers.contains(value);
} }
@Override @Override
public float matchCost() { public float matchCost() {
return 5; // lookup in the set return 5; // 2 comparisons, possible lookup in the set
} }
}; };
} else { } else {
@ -127,7 +132,12 @@ final class SortedNumericDocValuesSetQuery extends Query implements Accountable
public boolean matches() throws IOException { public boolean matches() throws IOException {
int count = values.docValueCount(); int count = values.docValueCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
if (numbers.contains(values.nextValue())) { final long value = values.nextValue();
if (value < numbers.minValue) {
continue;
} else if (value > numbers.maxValue) {
return false; // values are sorted, terminate
} else if (numbers.contains(value)) {
return true; return true;
} }
} }
@ -136,7 +146,7 @@ final class SortedNumericDocValuesSetQuery extends Query implements Accountable
@Override @Override
public float matchCost() { public float matchCost() {
return 5; // lookup in the set return 5; // 2 comparisons, possible lookup in the set
} }
}; };
} }

View File

@ -25,11 +25,10 @@ import org.apache.lucene.tests.util.LuceneTestCase;
public class TestLongHashSet extends LuceneTestCase { public class TestLongHashSet extends LuceneTestCase {
private void assertEquals(Set<Long> set1, LongHashSet set2) { private void assertEquals(Set<Long> set1, LongHashSet longHashSet) {
Set<Long> set2 = longHashSet.toSet();
LuceneTestCase.assertEquals(set1, set2); LuceneTestCase.assertEquals(set1, set2);
LuceneTestCase.assertEquals(set2, set1);
LuceneTestCase.assertEquals(set2, set2);
assertEquals(set1.hashCode(), set2.hashCode());
if (set1.isEmpty() == false) { if (set1.isEmpty() == false) {
Set<Long> set3 = new HashSet<>(set1); Set<Long> set3 = new HashSet<>(set1);
@ -40,40 +39,53 @@ public class TestLongHashSet extends LuceneTestCase {
break; break;
} }
} }
assertNotEquals(set3, set2); assertNotEquals(set3, longHashSet);
} }
} }
private void assertNotEquals(Set<Long> set1, LongHashSet set2) { private void assertNotEquals(Set<Long> set1, LongHashSet longHashSet) {
assertFalse(set1.equals(set2)); Set<Long> set2 = longHashSet.toSet();
assertFalse(set2.equals(set1));
LongHashSet set3 = new LongHashSet(set1.stream().mapToLong(Long::longValue).toArray()); LuceneTestCase.assertNotEquals(set1, set2);
assertFalse(set2.equals(set3));
LongHashSet set3 = new LongHashSet(set1.stream().mapToLong(Long::longValue).sorted().toArray());
LuceneTestCase.assertNotEquals(set2, set3.toSet());
} }
public void testEmpty() { public void testEmpty() {
Set<Long> set1 = new HashSet<>(); Set<Long> set1 = new HashSet<>();
LongHashSet set2 = new LongHashSet(new long[] {}); LongHashSet set2 = new LongHashSet(new long[] {});
assertEquals(Long.MAX_VALUE, set2.minValue);
assertEquals(Long.MIN_VALUE, set2.maxValue);
assertEquals(set1, set2); assertEquals(set1, set2);
} }
public void testOneValue() { public void testOneValue() {
Set<Long> set1 = new HashSet<>(Arrays.asList(42L)); Set<Long> set1 = new HashSet<>(Arrays.asList(42L));
LongHashSet set2 = new LongHashSet(new long[] {42L}); LongHashSet set2 = new LongHashSet(new long[] {42L});
assertEquals(42L, set2.minValue);
assertEquals(42L, set2.maxValue);
assertEquals(set1, set2); assertEquals(set1, set2);
set1 = new HashSet<>(Arrays.asList(Long.MIN_VALUE)); set1 = new HashSet<>(Arrays.asList(Long.MIN_VALUE));
set2 = new LongHashSet(new long[] {Long.MIN_VALUE}); set2 = new LongHashSet(new long[] {Long.MIN_VALUE});
assertEquals(Long.MIN_VALUE, set2.minValue);
assertEquals(Long.MIN_VALUE, set2.maxValue);
assertEquals(set1, set2); assertEquals(set1, set2);
} }
public void testTwoValues() { public void testTwoValues() {
Set<Long> set1 = new HashSet<>(Arrays.asList(42L, Long.MAX_VALUE)); Set<Long> set1 = new HashSet<>(Arrays.asList(42L, Long.MAX_VALUE));
LongHashSet set2 = new LongHashSet(new long[] {42L, Long.MAX_VALUE}); LongHashSet set2 = new LongHashSet(new long[] {42L, Long.MAX_VALUE});
assertEquals(42, set2.minValue);
assertEquals(Long.MAX_VALUE, set2.maxValue);
assertEquals(set1, set2); assertEquals(set1, set2);
set1 = new HashSet<>(Arrays.asList(Long.MIN_VALUE, 42L)); set1 = new HashSet<>(Arrays.asList(Long.MIN_VALUE, 42L));
set2 = new LongHashSet(new long[] {Long.MIN_VALUE, 42L}); set2 = new LongHashSet(new long[] {Long.MIN_VALUE, 42L});
assertEquals(Long.MIN_VALUE, set2.minValue);
assertEquals(42, set2.maxValue);
assertEquals(set1, set2); assertEquals(set1, set2);
} }
@ -95,6 +107,7 @@ public class TestLongHashSet extends LuceneTestCase {
LongStream.of(values) LongStream.of(values)
.mapToObj(Long::valueOf) .mapToObj(Long::valueOf)
.collect(Collectors.toCollection(HashSet::new)); .collect(Collectors.toCollection(HashSet::new));
Arrays.sort(values);
LongHashSet set2 = new LongHashSet(values); LongHashSet set2 = new LongHashSet(values);
assertEquals(set1, set2); assertEquals(set1, set2);
} }