From 3cb71436a7b7f8f3178cc8efbdf8248982ac41e0 Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Fri, 29 Sep 2023 20:11:29 -0500 Subject: [PATCH] Sort update terms with stable radix sorter (#12591) --- lucene/CHANGES.txt | 2 + .../lucene/index/FieldUpdatesBuffer.java | 67 ++++++++----- .../org/apache/lucene/util/BytesRefArray.java | 89 +++++++++-------- .../lucene/util/BytesRefComparator.java | 33 ++++++- .../org/apache/lucene/util/BytesRefHash.java | 9 +- .../lucene/util/FixedLengthBytesRefArray.java | 75 +++----------- .../lucene/util/StableStringSorter.java | 80 +++++++++++++++ .../lucene/util/StringMSBRadixSorter.java | 77 --------------- .../org/apache/lucene/util/StringSorter.java | 98 +++++++++++++++++++ .../lucene/index/TestFieldUpdatesBuffer.java | 28 ++---- .../apache/lucene/util/TestBytesRefArray.java | 58 ++++++++++- ...RadixSorter.java => TestStringSorter.java} | 64 +++++++++++- .../lucene/misc/index/BPIndexReorderer.java | 2 +- 13 files changed, 443 insertions(+), 239 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java delete mode 100644 lucene/core/src/java/org/apache/lucene/util/StringMSBRadixSorter.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/StringSorter.java rename lucene/core/src/test/org/apache/lucene/util/{TestStringMSBRadixSorter.java => TestStringSorter.java} (60%) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d4c1a223740..e361f8b9205 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -158,6 +158,8 @@ Optimizations * GITHUB#12382: Faster top-level conjunctions on term queries when sorting by descending score. (Adrien Grand) + +* GITHUB#12591: Use stable radix sort to speed up the sorting of update terms. (Guo Feng) Changes in runtime behavior --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldUpdatesBuffer.java b/lucene/core/src/java/org/apache/lucene/index/FieldUpdatesBuffer.java index bc04116e6b3..0f1ef3439bf 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FieldUpdatesBuffer.java +++ b/lucene/core/src/java/org/apache/lucene/index/FieldUpdatesBuffer.java @@ -19,11 +19,11 @@ package org.apache.lucene.index; import java.io.IOException; import java.util.Arrays; -import java.util.Comparator; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefArray; +import org.apache.lucene.util.BytesRefComparator; import org.apache.lucene.util.BytesRefIterator; import org.apache.lucene.util.Counter; import org.apache.lucene.util.FixedBitSet; @@ -50,7 +50,7 @@ final class FieldUpdatesBuffer { // we use a very simple approach and store the update term values without de-duplication // which is also not a common case to keep updating the same value more than once... // we might pay a higher price in terms of memory in certain cases but will gain - // on CPU for those. We also save on not needing to sort in order to apply the terms in order + // on CPU for those. We also use a stable sort to sort in order to apply the terms in order // since by definition we store them in order. private final BytesRefArray termValues; private BytesRefArray.SortState termSortState; @@ -212,19 +212,36 @@ final class FieldUpdatesBuffer { finished = true; final boolean sortedTerms = hasSingleValue() && hasValues == null && fields.length == 1; if (sortedTerms) { - // sort by ascending by term, then sort descending by docsUpTo so that we can skip updates - // with lower docUpTo. - termSortState = - termValues.sort( - Comparator.naturalOrder(), - (i1, i2) -> - Integer.compare( - docsUpTo[getArrayIndex(docsUpTo.length, i2)], - docsUpTo[getArrayIndex(docsUpTo.length, i1)])); + termSortState = termValues.sort(BytesRefComparator.NATURAL, true); + assert assertTermAndDocInOrder(); bytesUsed.addAndGet(termSortState.ramBytesUsed()); } } + private boolean assertTermAndDocInOrder() { + try { + BytesRefArray.IndexedBytesRefIterator iterator = termValues.iterator(termSortState); + BytesRef last = null; + int lastOrd = -1; + BytesRef current; + while ((current = iterator.next()) != null) { + if (last != null) { + int cmp = current.compareTo(last); + assert cmp >= 0 : "term in reverse order"; + assert cmp != 0 + || docsUpTo[getArrayIndex(docsUpTo.length, lastOrd)] + <= docsUpTo[getArrayIndex(docsUpTo.length, iterator.ord())] + : "doc id in reverse order"; + } + last = BytesRef.deepCopyOf(current); + lastOrd = iterator.ord(); + } + } catch (IOException e) { + assert false : e.getMessage(); + } + return true; + } + BufferedUpdateIterator iterator() { if (finished == false) { throw new IllegalStateException("buffer is not finished yet"); @@ -336,22 +353,22 @@ final class FieldUpdatesBuffer { BytesRef nextTerm() throws IOException { if (lookAheadTermIterator != null) { - final BytesRef lastTerm = bufferedUpdate.termValue; - BytesRef lookAheadTerm; - while ((lookAheadTerm = lookAheadTermIterator.next()) != null - && lookAheadTerm.equals(lastTerm)) { - BytesRef discardedTerm = - termValuesIterator.next(); // discard as the docUpTo of the previous update is higher - assert discardedTerm.equals(lookAheadTerm) - : "[" + discardedTerm + "] != [" + lookAheadTerm + "]"; - assert docsUpTo[getArrayIndex(docsUpTo.length, termValuesIterator.ord())] - <= bufferedUpdate.docUpTo - : docsUpTo[getArrayIndex(docsUpTo.length, termValuesIterator.ord())] - + ">" - + bufferedUpdate.docUpTo; + if (bufferedUpdate.termValue == null) { + lookAheadTermIterator.next(); } + BytesRef lastTerm, aheadTerm; + do { + aheadTerm = lookAheadTermIterator.next(); + lastTerm = termValuesIterator.next(); + } while (aheadTerm != null + // Shortcut to avoid equals, we did a stable sort before, so aheadTerm can only equal + // lastTerm when aheadTerm has a lager ord. + && lookAheadTermIterator.ord() > termValuesIterator.ord() + && aheadTerm.equals(lastTerm)); + return lastTerm; + } else { + return termValuesIterator.next(); } - return termValuesIterator.next(); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/BytesRefArray.java b/lucene/core/src/java/org/apache/lucene/util/BytesRefArray.java index 4ab96d8519c..6f69843586b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BytesRefArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/BytesRefArray.java @@ -19,7 +19,6 @@ package org.apache.lucene.util; import java.util.Arrays; import java.util.Comparator; import java.util.Objects; -import java.util.function.IntBinaryOperator; /** * A simple append only random-access {@link BytesRef} array that stores full copies of the appended @@ -120,54 +119,64 @@ public final class BytesRefArray implements SortableBytesRefArray { /** * Returns a {@link SortState} representing the order of elements in this array. This is a * non-destructive operation. + * + * @param comp The comparator to compare {@link BytesRef}s. A radix sort optimization is available + * if the comparator implements {@link BytesRefComparator} + * @param stable If the sort needs to be stable + * @return A {@link SortState} that could be used in {@link BytesRefArray#iterator(SortState)} */ - public SortState sort(final Comparator comp, final IntBinaryOperator tieComparator) { + public SortState sort(final Comparator comp, boolean stable) { final int[] orderedEntries = new int[size()]; for (int i = 0; i < orderedEntries.length; i++) { orderedEntries[i] = i; } - new IntroSorter() { - @Override - protected void swap(int i, int j) { - final int o = orderedEntries[i]; - orderedEntries[i] = orderedEntries[j]; - orderedEntries[j] = o; - } + StringSorter sorter; + if (stable) { + sorter = + new StableStringSorter(comp) { - @Override - protected int compare(int i, int j) { - final int idx1 = orderedEntries[i], idx2 = orderedEntries[j]; - setBytesRef(scratch1, scratchBytes1, idx1); - setBytesRef(scratch2, scratchBytes2, idx2); - return compare(idx1, scratchBytes1, idx2, scratchBytes2); - } + private final int[] tmp = new int[size()]; - @Override - protected void setPivot(int i) { - pivotIndex = orderedEntries[i]; - setBytesRef(pivotBuilder, pivot, pivotIndex); - } + @Override + protected void get(BytesRefBuilder builder, BytesRef result, int i) { + BytesRefArray.this.setBytesRef(builder, result, orderedEntries[i]); + } - @Override - protected int comparePivot(int j) { - final int index = orderedEntries[j]; - setBytesRef(scratch2, scratchBytes2, index); - return compare(pivotIndex, pivot, index, scratchBytes2); - } + @Override + protected void save(int i, int j) { + tmp[j] = orderedEntries[i]; + } - private int compare(int i1, BytesRef b1, int i2, BytesRef b2) { - int res = comp.compare(b1, b2); - return res == 0 ? tieComparator.applyAsInt(i1, i2) : res; - } + @Override + protected void restore(int i, int j) { + System.arraycopy(tmp, i, orderedEntries, i, j - i); + } - private int pivotIndex; - private final BytesRef pivot = new BytesRef(); - private final BytesRef scratchBytes1 = new BytesRef(); - private final BytesRef scratchBytes2 = new BytesRef(); - private final BytesRefBuilder pivotBuilder = new BytesRefBuilder(); - private final BytesRefBuilder scratch1 = new BytesRefBuilder(); - private final BytesRefBuilder scratch2 = new BytesRefBuilder(); - }.sort(0, size()); + @Override + protected void swap(int i, int j) { + int o = orderedEntries[i]; + orderedEntries[i] = orderedEntries[j]; + orderedEntries[j] = o; + } + }; + } else { + sorter = + new StringSorter(comp) { + @Override + protected void get(BytesRefBuilder builder, BytesRef result, int i) { + BytesRefArray.this.setBytesRef(builder, result, orderedEntries[i]); + } + + @Override + protected void swap(int i, int j) { + int o = orderedEntries[i]; + orderedEntries[i] = orderedEntries[j]; + orderedEntries[j] = o; + } + }; + } + + sorter.sort(0, size()); return new SortState(orderedEntries); } @@ -188,7 +197,7 @@ public final class BytesRefArray implements SortableBytesRefArray { */ @Override public BytesRefIterator iterator(final Comparator comp) { - return iterator(sort(comp, (i, j) -> 0)); + return iterator(sort(comp, false)); } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/BytesRefComparator.java b/lucene/core/src/java/org/apache/lucene/util/BytesRefComparator.java index 05261216fd1..0465ec13e6c 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BytesRefComparator.java +++ b/lucene/core/src/java/org/apache/lucene/util/BytesRefComparator.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.util; +import java.util.Arrays; import java.util.Comparator; /** @@ -26,6 +27,29 @@ import java.util.Comparator; */ public abstract class BytesRefComparator implements Comparator { + /** Comparing ByteRefs in natual order. */ + public static final BytesRefComparator NATURAL = + new BytesRefComparator(Integer.MAX_VALUE) { + @Override + protected int byteAt(BytesRef ref, int i) { + if (ref.length <= i) { + return -1; + } + return ref.bytes[ref.offset + i] & 0xFF; + } + + @Override + public int compare(BytesRef o1, BytesRef o2, int k) { + return Arrays.compareUnsigned( + o1.bytes, + o1.offset + k, + o1.offset + o1.length, + o2.bytes, + o2.offset + k, + o2.offset + o2.length); + } + }; + final int comparedBytesCount; /** @@ -45,8 +69,13 @@ public abstract class BytesRefComparator implements Comparator { protected abstract int byteAt(BytesRef ref, int i); @Override - public int compare(BytesRef o1, BytesRef o2) { - for (int i = 0; i < comparedBytesCount; ++i) { + public final int compare(BytesRef o1, BytesRef o2) { + return compare(o1, o2, 0); + } + + /** Compare two bytes refs that first k bytes are already guaranteed to be equal. */ + public int compare(BytesRef o1, BytesRef o2, int k) { + for (int i = k; i < comparedBytesCount; ++i) { final int b1 = byteAt(o1, i); final int b2 = byteAt(o2, i); if (b1 != b2) { diff --git a/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java b/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java index 99ad781ebb6..26c16c07d90 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java +++ b/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java @@ -149,9 +149,7 @@ public final class BytesRefHash implements Accountable { */ public int[] sort() { final int[] compact = compact(); - new StringMSBRadixSorter() { - - BytesRef scratch = new BytesRef(); + new StringSorter(BytesRefComparator.NATURAL) { @Override protected void swap(int i, int j) { @@ -161,9 +159,8 @@ public final class BytesRefHash implements Accountable { } @Override - protected BytesRef get(int i) { - pool.setBytesRef(scratch, bytesStart[compact[i]]); - return scratch; + protected void get(BytesRefBuilder builder, BytesRef result, int i) { + pool.setBytesRef(result, bytesStart[compact[i]]); } }.sort(0, count); return compact; diff --git a/lucene/core/src/java/org/apache/lucene/util/FixedLengthBytesRefArray.java b/lucene/core/src/java/org/apache/lucene/util/FixedLengthBytesRefArray.java index 18d81120d63..0f1a9df3cd5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/FixedLengthBytesRefArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/FixedLengthBytesRefArray.java @@ -106,43 +106,20 @@ final class FixedLengthBytesRefArray implements SortableBytesRefArray { orderedEntries[i] = i; } - if (comp instanceof BytesRefComparator) { - BytesRefComparator bComp = (BytesRefComparator) comp; - new MSBRadixSorter(bComp.comparedBytesCount) { + new StringSorter(comp) { - BytesRef scratch; + { + scratchBytes1.length = valueLength; + scratchBytes2.length = valueLength; + pivot.length = valueLength; + } - { - scratch = new BytesRef(); - scratch.length = valueLength; - } - - @Override - protected void swap(int i, int j) { - int o = orderedEntries[i]; - orderedEntries[i] = orderedEntries[j]; - orderedEntries[j] = o; - } - - @Override - protected int byteAt(int i, int k) { - int index1 = orderedEntries[i]; - scratch.bytes = blocks[index1 / valuesPerBlock]; - scratch.offset = (index1 % valuesPerBlock) * valueLength; - return bComp.byteAt(scratch, k); - } - }.sort(0, size()); - return orderedEntries; - } - - final BytesRef pivot = new BytesRef(); - final BytesRef scratch1 = new BytesRef(); - final BytesRef scratch2 = new BytesRef(); - pivot.length = valueLength; - scratch1.length = valueLength; - scratch2.length = valueLength; - - new IntroSorter() { + @Override + protected void get(BytesRefBuilder builder, BytesRef result, int i) { + final int index = orderedEntries[i]; + result.bytes = blocks[index / valuesPerBlock]; + result.offset = (index % valuesPerBlock) * valueLength; + } @Override protected void swap(int i, int j) { @@ -150,34 +127,6 @@ final class FixedLengthBytesRefArray implements SortableBytesRefArray { orderedEntries[i] = orderedEntries[j]; orderedEntries[j] = o; } - - @Override - protected int compare(int i, int j) { - int index1 = orderedEntries[i]; - scratch1.bytes = blocks[index1 / valuesPerBlock]; - scratch1.offset = (index1 % valuesPerBlock) * valueLength; - - int index2 = orderedEntries[j]; - scratch2.bytes = blocks[index2 / valuesPerBlock]; - scratch2.offset = (index2 % valuesPerBlock) * valueLength; - - return comp.compare(scratch1, scratch2); - } - - @Override - protected void setPivot(int i) { - int index = orderedEntries[i]; - pivot.bytes = blocks[index / valuesPerBlock]; - pivot.offset = (index % valuesPerBlock) * valueLength; - } - - @Override - protected int comparePivot(int j) { - final int index = orderedEntries[j]; - scratch2.bytes = blocks[index / valuesPerBlock]; - scratch2.offset = (index % valuesPerBlock) * valueLength; - return comp.compare(pivot, scratch2); - } }.sort(0, size()); return orderedEntries; } diff --git a/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java b/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java new file mode 100644 index 00000000000..067a72e8084 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import java.util.Comparator; + +abstract class StableStringSorter extends StringSorter { + + StableStringSorter(Comparator cmp) { + super(cmp); + } + + /** Save the i-th value into the j-th position in temporary storage. */ + protected abstract void save(int i, int j); + + /** Restore values between i-th and j-th(excluding) in temporary storage into original storage. */ + protected abstract void restore(int i, int j); + + @Override + protected Sorter radixSorter(BytesRefComparator cmp) { + return new StableMSBRadixSorter(cmp.comparedBytesCount) { + + @Override + protected void save(int i, int j) { + StableStringSorter.this.save(i, j); + } + + @Override + protected void restore(int i, int j) { + StableStringSorter.this.restore(i, j); + } + + @Override + protected void swap(int i, int j) { + StableStringSorter.this.swap(i, j); + } + + @Override + protected int byteAt(int i, int k) { + get(scratch1, scratchBytes1, i); + return cmp.byteAt(scratchBytes1, k); + } + + @Override + protected Sorter getFallbackSorter(int k) { + return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k)); + } + }; + } + + @Override + protected Sorter fallbackSorter(Comparator cmp) { + // TODO: Maybe tim sort is better? + return new InPlaceMergeSorter() { + @Override + protected int compare(int i, int j) { + return StableStringSorter.this.compare(i, j); + } + + @Override + protected void swap(int i, int j) { + StableStringSorter.this.swap(i, j); + } + }; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/StringMSBRadixSorter.java b/lucene/core/src/java/org/apache/lucene/util/StringMSBRadixSorter.java deleted file mode 100644 index d8b51617a04..00000000000 --- a/lucene/core/src/java/org/apache/lucene/util/StringMSBRadixSorter.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.util; - -abstract class StringMSBRadixSorter extends MSBRadixSorter { - - StringMSBRadixSorter() { - super(Integer.MAX_VALUE); - } - - /** Get a {@link BytesRef} for the given index. */ - protected abstract BytesRef get(int i); - - @Override - protected int byteAt(int i, int k) { - BytesRef ref = get(i); - if (ref.length <= k) { - return -1; - } - return ref.bytes[ref.offset + k] & 0xff; - } - - @Override - protected Sorter getFallbackSorter(int k) { - return new IntroSorter() { - - private void get(int i, int k, BytesRef scratch) { - BytesRef ref = StringMSBRadixSorter.this.get(i); - assert ref.length >= k; - scratch.bytes = ref.bytes; - scratch.offset = ref.offset + k; - scratch.length = ref.length - k; - } - - @Override - protected void swap(int i, int j) { - StringMSBRadixSorter.this.swap(i, j); - } - - @Override - protected int compare(int i, int j) { - get(i, k, scratch1); - get(j, k, scratch2); - return scratch1.compareTo(scratch2); - } - - @Override - protected void setPivot(int i) { - get(i, k, pivot); - } - - @Override - protected int comparePivot(int j) { - get(j, k, scratch2); - return pivot.compareTo(scratch2); - } - - private final BytesRef pivot = new BytesRef(), - scratch1 = new BytesRef(), - scratch2 = new BytesRef(); - }; - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/StringSorter.java b/lucene/core/src/java/org/apache/lucene/util/StringSorter.java new file mode 100644 index 00000000000..92761935c72 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/StringSorter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +import java.util.Comparator; + +abstract class StringSorter extends Sorter { + + private final Comparator cmp; + protected final BytesRefBuilder scratch1 = new BytesRefBuilder(); + protected final BytesRefBuilder scratch2 = new BytesRefBuilder(); + protected final BytesRefBuilder pivotBuilder = new BytesRefBuilder(); + protected final BytesRef scratchBytes1 = new BytesRef(); + protected final BytesRef scratchBytes2 = new BytesRef(); + protected final BytesRef pivot = new BytesRef(); + + StringSorter(Comparator cmp) { + this.cmp = cmp; + } + + protected abstract void get(BytesRefBuilder builder, BytesRef result, int i); + + @Override + protected int compare(int i, int j) { + get(scratch1, scratchBytes1, i); + get(scratch2, scratchBytes2, j); + return cmp.compare(scratchBytes1, scratchBytes2); + } + + @Override + public void sort(int from, int to) { + if (cmp instanceof BytesRefComparator bCmp) { + radixSorter(bCmp).sort(from, to); + } else { + fallbackSorter(cmp).sort(from, to); + } + } + + protected Sorter radixSorter(BytesRefComparator cmp) { + return new MSBRadixSorter(cmp.comparedBytesCount) { + @Override + protected void swap(int i, int j) { + StringSorter.this.swap(i, j); + } + + @Override + protected int byteAt(int i, int k) { + get(scratch1, scratchBytes1, i); + return cmp.byteAt(scratchBytes1, k); + } + + @Override + protected Sorter getFallbackSorter(int k) { + return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k)); + } + }; + } + + protected Sorter fallbackSorter(Comparator cmp) { + return new IntroSorter() { + @Override + protected void swap(int i, int j) { + StringSorter.this.swap(i, j); + } + + @Override + protected int compare(int i, int j) { + return StringSorter.this.compare(i, j); + } + + @Override + protected void setPivot(int i) { + get(pivotBuilder, pivot, i); + } + + @Override + protected int comparePivot(int j) { + get(scratch1, scratchBytes1, j); + return cmp.compare(pivot, scratchBytes1); + } + }; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldUpdatesBuffer.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldUpdatesBuffer.java index f8b817d3ef4..5b43426b0e4 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestFieldUpdatesBuffer.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldUpdatesBuffer.java @@ -154,15 +154,7 @@ public class TestFieldUpdatesBuffer extends LuceneTestCase { assertFalse(buffer.isNumeric()); } - int randomDocUpTo() { - if (random().nextInt(5) == 0) { - return Integer.MAX_VALUE; - } else { - return random().nextInt(10000); - } - } - - DocValuesUpdate.BinaryDocValuesUpdate getRandomBinaryUpdate() { + DocValuesUpdate.BinaryDocValuesUpdate getRandomBinaryUpdate(int docIdUpTo) { String termField = RandomPicks.randomFrom(random(), Arrays.asList("id", "_id", "some_other_field")); String docId = "" + random().nextInt(10); @@ -171,10 +163,10 @@ public class TestFieldUpdatesBuffer extends LuceneTestCase { new Term(termField, docId), "binary", rarely() ? null : new BytesRef(TestUtil.randomRealisticUnicodeString(random()))); - return rarely() ? value.prepareForApply(randomDocUpTo()) : value; + return rarely() ? value.prepareForApply(docIdUpTo) : value; } - DocValuesUpdate.NumericDocValuesUpdate getRandomNumericUpdate() { + DocValuesUpdate.NumericDocValuesUpdate getRandomNumericUpdate(int docIdUpTo) { String termField = RandomPicks.randomFrom(random(), Arrays.asList("id", "_id", "some_other_field")); String docId = "" + random().nextInt(10); @@ -183,19 +175,19 @@ public class TestFieldUpdatesBuffer extends LuceneTestCase { new Term(termField, docId), "numeric", rarely() ? null : Long.valueOf(random().nextInt(100))); - return rarely() ? value.prepareForApply(randomDocUpTo()) : value; + return rarely() ? value.prepareForApply(docIdUpTo) : value; } public void testBinaryRandom() throws IOException { List updates = new ArrayList<>(); int numUpdates = 1 + random().nextInt(1000); Counter counter = Counter.newCounter(); - DocValuesUpdate.BinaryDocValuesUpdate randomUpdate = getRandomBinaryUpdate(); + DocValuesUpdate.BinaryDocValuesUpdate randomUpdate = getRandomBinaryUpdate(0); updates.add(randomUpdate); FieldUpdatesBuffer buffer = new FieldUpdatesBuffer(counter, randomUpdate, randomUpdate.docIDUpTo); for (int i = 0; i < numUpdates; i++) { - randomUpdate = getRandomBinaryUpdate(); + randomUpdate = getRandomBinaryUpdate(i + 1); updates.add(randomUpdate); if (randomUpdate.hasValue) { buffer.addUpdate(randomUpdate.term, randomUpdate.getValue(), randomUpdate.docIDUpTo); @@ -227,12 +219,12 @@ public class TestFieldUpdatesBuffer extends LuceneTestCase { List updates = new ArrayList<>(); int numUpdates = 1 + random().nextInt(1000); Counter counter = Counter.newCounter(); - DocValuesUpdate.NumericDocValuesUpdate randomUpdate = getRandomNumericUpdate(); + DocValuesUpdate.NumericDocValuesUpdate randomUpdate = getRandomNumericUpdate(0); updates.add(randomUpdate); FieldUpdatesBuffer buffer = new FieldUpdatesBuffer(counter, randomUpdate, randomUpdate.docIDUpTo); for (int i = 0; i < numUpdates; i++) { - randomUpdate = getRandomNumericUpdate(); + randomUpdate = getRandomNumericUpdate(i + 1); updates.add(randomUpdate); if (randomUpdate.hasValue) { buffer.addUpdate(randomUpdate.term, randomUpdate.getValue(), randomUpdate.docIDUpTo); @@ -272,7 +264,7 @@ public class TestFieldUpdatesBuffer extends LuceneTestCase { DocValuesUpdate.NumericDocValuesUpdate randomUpdate = new DocValuesUpdate.NumericDocValuesUpdate( new Term(termField, Integer.toString(random().nextInt(1000))), "numeric", docValue); - randomUpdate = randomUpdate.prepareForApply(randomDocUpTo()); + randomUpdate = randomUpdate.prepareForApply(0); updates.add(randomUpdate); FieldUpdatesBuffer buffer = new FieldUpdatesBuffer(counter, randomUpdate, randomUpdate.docIDUpTo); @@ -280,7 +272,7 @@ public class TestFieldUpdatesBuffer extends LuceneTestCase { randomUpdate = new DocValuesUpdate.NumericDocValuesUpdate( new Term(termField, Integer.toString(random().nextInt(1000))), "numeric", docValue); - randomUpdate = randomUpdate.prepareForApply(randomDocUpTo()); + randomUpdate = randomUpdate.prepareForApply(i + 1); updates.add(randomUpdate); buffer.addUpdate(randomUpdate.term, randomUpdate.getValue(), randomUpdate.docIDUpTo); } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestBytesRefArray.java b/lucene/core/src/test/org/apache/lucene/util/TestBytesRefArray.java index 95ff5254742..e2e0125ba34 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestBytesRefArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestBytesRefArray.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.util; +import com.carrotsearch.randomizedtesting.generators.RandomPicks; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -73,12 +74,12 @@ public class TestBytesRefArray extends LuceneTestCase { BytesRefArray list = new BytesRefArray(Counter.newCounter()); List stringList = new ArrayList<>(); - for (int j = 0; j < 2; j++) { + for (int j = 0; j < 5; j++) { if (j > 0 && random.nextBoolean()) { list.clear(); stringList.clear(); } - int entries = atLeast(500); + int entries = atLeast(200); BytesRefBuilder spare = new BytesRefBuilder(); final int initSize = list.size(); for (int i = 0; i < entries; i++) { @@ -89,7 +90,9 @@ public class TestBytesRefArray extends LuceneTestCase { } Collections.sort(stringList, TestUtil.STRING_CODEPOINT_COMPARATOR); - BytesRefIterator iter = list.iterator(Comparator.naturalOrder()); + BytesRefIterator iter = + list.iterator( + random().nextBoolean() ? Comparator.naturalOrder() : BytesRefComparator.NATURAL); int i = 0; BytesRef next; while ((next = iter.next()) != null) { @@ -100,4 +103,53 @@ public class TestBytesRefArray extends LuceneTestCase { assertEquals(i, stringList.size()); } } + + public void testStableSort() throws IOException { + Random random = random(); + BytesRefArray list = new BytesRefArray(Counter.newCounter()); + List stringList = new ArrayList<>(); + + for (int j = 0; j < 5; j++) { + if (j > 0 && random.nextBoolean()) { + list.clear(); + stringList.clear(); + } + int entries = atLeast(200); + String[] values = new String[20]; + for (int i = 0; i < values.length; i++) { + values[i] = TestUtil.randomRealisticUnicodeString(random); + } + BytesRefBuilder spare = new BytesRefBuilder(); + final int initSize = list.size(); + for (int i = 0; i < entries; i++) { + String randomRealisticUnicodeString = RandomPicks.randomFrom(random, values); + spare.copyChars(randomRealisticUnicodeString); + assertEquals(initSize + i, list.append(spare.get())); + stringList.add(randomRealisticUnicodeString); + } + + Collections.sort(stringList, TestUtil.STRING_CODEPOINT_COMPARATOR); + BytesRefArray.SortState state = + list.sort( + random().nextBoolean() ? Comparator.naturalOrder() : BytesRefComparator.NATURAL, + true); + BytesRefArray.IndexedBytesRefIterator iter = list.iterator(state); + int i = 0; + int lastOrd = -1; + BytesRef last = null; + BytesRef next; + while ((next = iter.next()) != null) { + assertEquals("entry " + i + " doesn't match", stringList.get(i), next.utf8ToString()); + i++; + + if (next.equals(last)) { + assertTrue("sort not stable: " + iter.ord() + " <= " + lastOrd, iter.ord() > lastOrd); + } + last = BytesRef.deepCopyOf(next); + lastOrd = iter.ord(); + } + assertNull(iter.next()); + assertEquals(i, stringList.size()); + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestStringMSBRadixSorter.java b/lucene/core/src/test/org/apache/lucene/util/TestStringSorter.java similarity index 60% rename from lucene/core/src/test/org/apache/lucene/util/TestStringMSBRadixSorter.java rename to lucene/core/src/test/org/apache/lucene/util/TestStringSorter.java index aa06ac78635..0d9af398738 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestStringMSBRadixSorter.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestStringSorter.java @@ -17,20 +17,32 @@ package org.apache.lucene.util; import java.util.Arrays; +import java.util.Comparator; +import java.util.stream.IntStream; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; -public class TestStringMSBRadixSorter extends LuceneTestCase { +public class TestStringSorter extends LuceneTestCase { private void test(BytesRef[] refs, int len) { + test(ArrayUtil.copyOfSubArray(refs, 0, len), len, BytesRefComparator.NATURAL); + test(ArrayUtil.copyOfSubArray(refs, 0, len), len, Comparator.naturalOrder()); + testStable(ArrayUtil.copyOfSubArray(refs, 0, len), len, BytesRefComparator.NATURAL); + testStable(ArrayUtil.copyOfSubArray(refs, 0, len), len, Comparator.naturalOrder()); + } + + private void test(BytesRef[] refs, int len, Comparator comparator) { BytesRef[] expected = ArrayUtil.copyOfSubArray(refs, 0, len); Arrays.sort(expected); - new StringMSBRadixSorter() { + new StringSorter(comparator) { @Override - protected BytesRef get(int i) { - return refs[i]; + protected void get(BytesRefBuilder builder, BytesRef result, int i) { + BytesRef ref = refs[i]; + result.offset = ref.offset; + result.length = ref.length; + result.bytes = ref.bytes; } @Override @@ -44,6 +56,50 @@ public class TestStringMSBRadixSorter extends LuceneTestCase { assertArrayEquals(expected, actual); } + private void testStable(BytesRef[] refs, int len, Comparator comparator) { + BytesRef[] expected = ArrayUtil.copyOfSubArray(refs, 0, len); + Arrays.sort(expected); + + int[] ord = new int[len]; + IntStream.range(0, len).forEach(i -> ord[i] = i); + new StableStringSorter(comparator) { + + final int[] tmp = new int[len]; + + @Override + protected void save(int i, int j) { + tmp[j] = ord[i]; + } + + @Override + protected void restore(int i, int j) { + System.arraycopy(tmp, i, ord, i, j - i); + } + + @Override + protected void get(BytesRefBuilder builder, BytesRef result, int i) { + BytesRef ref = refs[ord[i]]; + result.offset = ref.offset; + result.length = ref.length; + result.bytes = ref.bytes; + } + + @Override + protected void swap(int i, int j) { + int tmp = ord[i]; + ord[i] = ord[j]; + ord[j] = tmp; + } + }.sort(0, len); + + for (int i = 0; i < len; i++) { + assertEquals(expected[i], refs[ord[i]]); + if (i > 0 && expected[i].equals(expected[i - 1])) { + assertTrue("not stable: " + ord[i] + " <= " + ord[i - 1], ord[i] > ord[i - 1]); + } + } + } + public void testEmpty() { test(new BytesRef[random().nextInt(5)], 0); } diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java index 5fe08b4d124..b36ba85d086 100644 --- a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java +++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java @@ -679,7 +679,7 @@ public final class BPIndexReorderer { } @Override - public int compare(BytesRef o1, BytesRef o2) { + public int compare(BytesRef o1, BytesRef o2, int k) { assert o1.length == 2 * Integer.BYTES; assert o2.length == 2 * Integer.BYTES; return ArrayUtil.compareUnsigned8(o1.bytes, o1.offset, o2.bytes, o2.offset);