LUCENE-10384: Simplify LongHeap. (#615)

The min/max ordering logic moves to NeighborQueue.
This commit is contained in:
Adrien Grand 2022-01-25 09:04:52 +01:00 committed by GitHub
parent eaf3cb6739
commit 07fe46ff86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 112 deletions

View File

@ -60,7 +60,7 @@ final class PForUtil {
/** Encode 128 integers from {@code longs} into {@code out}. */
void encode(long[] longs, DataOutput out) throws IOException {
// Determine the top MAX_EXCEPTIONS + 1 values
final LongHeap top = LongHeap.create(LongHeap.Order.MIN, MAX_EXCEPTIONS + 1);
final LongHeap top = new LongHeap(MAX_EXCEPTIONS + 1);
for (int i = 0; i <= MAX_EXCEPTIONS; ++i) {
top.push(longs[i]);
}

View File

@ -17,27 +17,16 @@
package org.apache.lucene.util;
/**
* A heap that stores longs; a primitive priority queue that like all priority queues maintains a
* partial ordering of its elements such that the least element can always be found in constant
* A min heap that stores longs; a primitive priority queue that like all priority queues maintains
* a partial ordering of its elements such that the least element can always be found in constant
* time. Calls to push() and pop() take log(size) time. This heap provides unbounded growth via {@link
* #push(long)}, and bounded-size insertion based on its nominal maxSize via {@link
* #insertWithOverflow(long)}. The heap may be either a min heap, in which case the least element is
* the smallest integer, or a max heap, when it is the largest, depending on the Order parameter.
* #insertWithOverflow(long)}. The heap is a min heap, meaning that the top element is the lowest
* value of the heap.
*
* @lucene.internal
*/
public abstract class LongHeap {
/**
* Used to specify the ordering of the heap. A min-heap provides access to the minimum element in
* constant time, and when bounded, retains the maximum <code>maxSize</code> elements. A max-heap
* conversely provides access to the maximum element in constant time, and when bounded retains
* the minimum <code>maxSize</code> elements.
*/
public enum Order {
MIN,
MAX
}
public final class LongHeap {
private final int maxSize;
@ -50,7 +39,7 @@ public abstract class LongHeap {
* @param maxSize the nominal maximum size of the heap; must be at least 1 and less than {@link
* ArrayUtil#MAX_ARRAY_LENGTH}, otherwise an IllegalArgumentException is thrown
*/
LongHeap(int maxSize) {
public LongHeap(int maxSize) {
final int heapSize;
if (maxSize < 1 || maxSize >= ArrayUtil.MAX_ARRAY_LENGTH) {
// Throw exception to prevent confusing OOME:
@ -63,33 +52,6 @@ public abstract class LongHeap {
this.heap = new long[heapSize];
}
public static LongHeap create(Order order, int maxSize) {
// TODO: override push() for unbounded queue
if (order == Order.MIN) {
return new LongHeap(maxSize) {
@Override
public boolean lessThan(long a, long b) {
return a < b;
}
};
} else {
return new LongHeap(maxSize) {
@Override
public boolean lessThan(long a, long b) {
return a > b;
}
};
}
}
/**
* Determines the ordering of objects in this priority queue. Subclasses must define this one
* method.
*
* @return <code>true</code> iff parameter <code>a</code> is less than parameter <code>b</code>.
*/
public abstract boolean lessThan(long a, long b);
/**
* Adds a value in log(size) time. Grows unbounded as needed to accommodate new values.
*
@ -114,7 +76,7 @@ public abstract class LongHeap {
*/
public boolean insertWithOverflow(long value) {
if (size >= maxSize) {
if (lessThan(value, heap[1])) {
if (value < heap[1]) {
return false;
}
updateTop(value);
@ -190,7 +152,7 @@ public abstract class LongHeap {
int i = origPos;
long value = heap[i]; // save bottom value
int j = i >>> 1;
while (j > 0 && lessThan(value, heap[j])) {
while (j > 0 && value < heap[j]) {
heap[i] = heap[j]; // shift parents down
i = j;
j = j >>> 1;
@ -202,15 +164,15 @@ public abstract class LongHeap {
long value = heap[i]; // save top value
int j = i << 1; // find smaller child
int k = j + 1;
if (k <= size && lessThan(heap[k], heap[j])) {
if (k <= size && heap[k] < heap[j]) {
j = k;
}
while (j <= size && lessThan(heap[j], value)) {
while (j <= size && heap[j] < value) {
heap[i] = heap[j]; // shift up child
i = j;
j = i << 1;
k = j + 1;
if (k <= size && lessThan(heap[k], heap[j])) {
if (k <= size && heap[k] < heap[j]) {
j = k;
}
}
@ -236,7 +198,8 @@ public abstract class LongHeap {
*
* @lucene.internal
*/
protected final long[] getHeapArray() {
// pkg-private for testing
final long[] getHeapArray() {
return heap;
}
}

View File

@ -29,17 +29,34 @@ import org.apache.lucene.util.NumericUtils;
*/
public class NeighborQueue {
private static enum Order {
NATURAL {
@Override
long apply(long v) {
return v;
}
},
REVERSED {
@Override
long apply(long v) {
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
// needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa.
return -1 - v;
}
};
abstract long apply(long v);
}
private final LongHeap heap;
private final Order order;
// Used to track the number of neighbors visited during a single graph traversal
private int visitedCount;
NeighborQueue(int initialSize, boolean reversed) {
if (reversed) {
heap = LongHeap.create(LongHeap.Order.MAX, initialSize);
} else {
heap = LongHeap.create(LongHeap.Order.MIN, initialSize);
}
this.heap = new LongHeap(initialSize);
this.order = reversed ? Order.REVERSED : Order.NATURAL;
}
/** @return the number of elements in the heap */
@ -71,12 +88,12 @@ public class NeighborQueue {
}
private long encode(int node, float score) {
return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
}
/** Removes the top element and returns its node id. */
public int pop() {
return (int) heap.pop();
return (int) order.apply(heap.pop());
}
int[] nodes() {
@ -90,12 +107,12 @@ public class NeighborQueue {
/** Returns the top element's node id. */
public int topNode() {
return (int) heap.top();
return (int) order.apply(heap.top());
}
/** Returns the top element's node score. */
public float topScore() {
return NumericUtils.sortableIntToFloat((int) (heap.top() >> 32));
return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
}
public int visitedCount() {

View File

@ -16,9 +16,6 @@
*/
package org.apache.lucene.util;
import static org.apache.lucene.util.LongHeap.Order.MAX;
import static org.apache.lucene.util.LongHeap.Order.MIN;
import java.util.ArrayList;
import java.util.Random;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -26,26 +23,11 @@ import org.apache.lucene.tests.util.TestUtil;
public class TestLongHeap extends LuceneTestCase {
private static class AssertingLongHeap extends LongHeap {
AssertingLongHeap(int count) {
super(count);
}
@Override
public boolean lessThan(long a, long b) {
return (a < b);
}
final void checkValidity() {
long[] heapArray = getHeapArray();
for (int i = 1; i <= size(); i++) {
int parent = i >>> 1;
if (parent > 1) {
if (lessThan(heapArray[parent], heapArray[i]) == false) {
assertEquals(heapArray[parent], heapArray[i]);
}
}
}
/**
 * Asserts the min-heap invariant on the backing array of {@code heap}: every slot from index 2
 * up to the heap's size must hold a value that is not smaller than the value in its parent slot
 * (index / 2). Index 1 is the root and has no parent to compare against.
 *
 * @param heap the heap whose internal array is inspected
 */
private static void checkValidity(LongHeap heap) {
  final long[] slots = heap.getHeapArray();
  final int last = heap.size();
  for (int child = 2; child <= last; child++) {
    final long parentValue = slots[child >>> 1];
    assert parentValue <= slots[child];
  }
}
@ -54,7 +36,7 @@ public class TestLongHeap extends LuceneTestCase {
}
public static void testPQ(int count, Random gen) {
LongHeap pq = LongHeap.create(MIN, count);
LongHeap pq = new LongHeap(count);
long sum = 0, sum2 = 0;
for (int i = 0; i < count; i++) {
@ -75,7 +57,7 @@ public class TestLongHeap extends LuceneTestCase {
}
public void testClear() {
LongHeap pq = LongHeap.create(MIN, 3);
LongHeap pq = new LongHeap(3);
pq.push(2);
pq.push(3);
pq.push(1);
@ -85,7 +67,7 @@ public class TestLongHeap extends LuceneTestCase {
}
public void testExceedBounds() {
LongHeap pq = LongHeap.create(MIN, 1);
LongHeap pq = new LongHeap(1);
pq.push(2);
pq.push(0);
// expectThrows(ArrayIndexOutOfBoundsException.class, () -> pq.push(0));
@ -94,7 +76,7 @@ public class TestLongHeap extends LuceneTestCase {
}
public void testFixedSize() {
LongHeap pq = LongHeap.create(MIN, 3);
LongHeap pq = new LongHeap(3);
pq.insertWithOverflow(2);
pq.insertWithOverflow(3);
pq.insertWithOverflow(1);
@ -105,20 +87,8 @@ public class TestLongHeap extends LuceneTestCase {
assertEquals(3, pq.top());
}
public void testFixedSizeMax() {
LongHeap pq = LongHeap.create(MAX, 3);
pq.insertWithOverflow(2);
pq.insertWithOverflow(3);
pq.insertWithOverflow(1);
pq.insertWithOverflow(5);
pq.insertWithOverflow(7);
pq.insertWithOverflow(1);
assertEquals(3, pq.size());
assertEquals(2, pq.top());
}
public void testDuplicateValues() {
LongHeap pq = LongHeap.create(MIN, 3);
LongHeap pq = new LongHeap(3);
pq.push(2);
pq.push(3);
pq.push(1);
@ -131,7 +101,7 @@ public class TestLongHeap extends LuceneTestCase {
public void testInsertions() {
Random random = random();
int numDocsInPQ = TestUtil.nextInt(random, 1, 100);
AssertingLongHeap pq = new AssertingLongHeap(numDocsInPQ);
LongHeap pq = new LongHeap(numDocsInPQ);
Long lastLeast = null;
// Basic insertion of new content
@ -140,7 +110,7 @@ public class TestLongHeap extends LuceneTestCase {
long newEntry = Math.abs(random.nextLong());
sds.add(newEntry);
pq.insertWithOverflow(newEntry);
pq.checkValidity();
checkValidity(pq);
long newLeast = pq.top();
if ((lastLeast != null) && (newLeast != newEntry) && (newLeast != lastLeast)) {
// If there has been a change of least entry and it wasn't our new
@ -153,17 +123,16 @@ public class TestLongHeap extends LuceneTestCase {
}
public void testInvalid() {
expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, -1));
expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, 0));
expectThrows(
IllegalArgumentException.class, () -> LongHeap.create(MAX, ArrayUtil.MAX_ARRAY_LENGTH));
expectThrows(IllegalArgumentException.class, () -> new LongHeap(-1));
expectThrows(IllegalArgumentException.class, () -> new LongHeap(0));
expectThrows(IllegalArgumentException.class, () -> new LongHeap(ArrayUtil.MAX_ARRAY_LENGTH));
}
public void testUnbounded() {
int initialSize = random().nextInt(10) + 1;
LongHeap pq = LongHeap.create(MAX, initialSize);
LongHeap pq = new LongHeap(initialSize);
int num = random().nextInt(100) + 1;
long minValue = Long.MAX_VALUE;
long maxValue = Long.MIN_VALUE;
int count = 0;
for (int i = 0; i < num; i++) {
long value = random().nextLong();
@ -178,19 +147,19 @@ public class TestLongHeap extends LuceneTestCase {
}
}
}
minValue = Math.min(minValue, value);
maxValue = Math.max(maxValue, value);
}
assertEquals(count, pq.size());
long last = Long.MAX_VALUE;
long last = Long.MIN_VALUE;
while (pq.size() > 0) {
long top = pq.top();
long next = pq.pop();
assertEquals(top, next);
--count;
assertTrue(next <= last);
assertTrue(next >= last);
last = next;
}
assertEquals(0, count);
assertEquals(minValue, last);
assertEquals(maxValue, last);
}
}