LUCENE-9063: Speed up computation of impacts. (#1038)

The current design of CompetitiveImpactAccumulator treats norms in -128..127
as a special case that should be optimized. This commit goes a bit further by
treating it as the normal case, and only ever adding impacts to the TreeSet if
the norm is outside of the byte range. It avoids a number of operations on
TreeSets like adding impacts or removing redundant impacts.
This commit is contained in:
Adrien Grand 2019-11-26 11:49:57 +01:00 committed by GitHub
parent e37e56c795
commit ded8efa82a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 47 deletions

View File

@@ -19,8 +19,7 @@ package org.apache.lucene.codecs.lucene50;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Set; import java.util.Collection;
import java.util.SortedSet;
import org.apache.lucene.codecs.CompetitiveImpactAccumulator; import org.apache.lucene.codecs.CompetitiveImpactAccumulator;
import org.apache.lucene.codecs.MultiLevelSkipListWriter; import org.apache.lucene.codecs.MultiLevelSkipListWriter;
@@ -141,7 +140,7 @@ final class Lucene50SkipWriter extends MultiLevelSkipListWriter {
// sets of competitive freq,norm pairs should be empty at this point // sets of competitive freq,norm pairs should be empty at this point
assert Arrays.stream(curCompetitiveFreqNorms) assert Arrays.stream(curCompetitiveFreqNorms)
.map(CompetitiveImpactAccumulator::getCompetitiveFreqNormPairs) .map(CompetitiveImpactAccumulator::getCompetitiveFreqNormPairs)
.mapToInt(Set::size) .mapToInt(Collection::size)
.sum() == 0; .sum() == 0;
initialized = true; initialized = true;
} }
@@ -205,7 +204,7 @@ final class Lucene50SkipWriter extends MultiLevelSkipListWriter {
} }
static void writeImpacts(CompetitiveImpactAccumulator acc, DataOutput out) throws IOException { static void writeImpacts(CompetitiveImpactAccumulator acc, DataOutput out) throws IOException {
SortedSet<Impact> impacts = acc.getCompetitiveFreqNormPairs(); Collection<Impact> impacts = acc.getCompetitiveFreqNormPairs();
Impact previous = new Impact(0, 0); Impact previous = new Impact(0, 0);
for (Impact impact : impacts) { for (Impact impact : impacts) {
assert impact.freq > previous.freq; assert impact.freq > previous.freq;

View File

@@ -16,11 +16,13 @@
*/ */
package org.apache.lucene.codecs; package org.apache.lucene.codecs;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.SortedSet; import java.util.List;
import java.util.TreeSet; import java.util.TreeSet;
import org.apache.lucene.index.Impact; import org.apache.lucene.index.Impact;
@@ -30,11 +32,14 @@ import org.apache.lucene.index.Impact;
*/ */
public final class CompetitiveImpactAccumulator { public final class CompetitiveImpactAccumulator {
// We speed up accumulation for common norm values by first computing // We speed up accumulation for common norm values with this array that maps
// the max freq for all norms in -128..127 // norm values in -128..127 to the maximum frequency observed for these norm
// values
private final int[] maxFreqs; private final int[] maxFreqs;
private boolean dirty; // This TreeSet stores competitive (freq,norm) pairs for norm values that fall
private final TreeSet<Impact> freqNormPairs; // outside of -128..127. It is always empty with the default similarity, which
// encodes norms as bytes.
private final TreeSet<Impact> otherFreqNormPairs;
/** Sole constructor. */ /** Sole constructor. */
public CompetitiveImpactAccumulator() { public CompetitiveImpactAccumulator() {
@@ -51,14 +56,14 @@ public final class CompetitiveImpactAccumulator {
return cmp; return cmp;
} }
}; };
freqNormPairs = new TreeSet<>(comparator); otherFreqNormPairs = new TreeSet<>(comparator);
} }
/** Reset to the same state it was in after creation. */ /** Reset to the same state it was in after creation. */
public void clear() { public void clear() {
Arrays.fill(maxFreqs, 0); Arrays.fill(maxFreqs, 0);
dirty = false; otherFreqNormPairs.clear();
freqNormPairs.clear(); assertConsistent();
} }
/** Accumulate a (freq,norm) pair, updating this structure if there is no /** Accumulate a (freq,norm) pair, updating this structure if there is no
@@ -67,34 +72,52 @@ public final class CompetitiveImpactAccumulator {
if (norm >= Byte.MIN_VALUE && norm <= Byte.MAX_VALUE) { if (norm >= Byte.MIN_VALUE && norm <= Byte.MAX_VALUE) {
int index = Byte.toUnsignedInt((byte) norm); int index = Byte.toUnsignedInt((byte) norm);
maxFreqs[index] = Math.max(maxFreqs[index], freq); maxFreqs[index] = Math.max(maxFreqs[index], freq);
dirty = true;
} else { } else {
add(new Impact(freq, norm)); add(new Impact(freq, norm), otherFreqNormPairs);
} }
assertConsistent();
} }
/** Merge {@code acc} into this. */ /** Merge {@code acc} into this. */
public void addAll(CompetitiveImpactAccumulator acc) { public void addAll(CompetitiveImpactAccumulator acc) {
for (Impact entry : acc.getCompetitiveFreqNormPairs()) { int[] maxFreqs = this.maxFreqs;
add(entry); int[] otherMaxFreqs = acc.maxFreqs;
for (int i = 0; i < maxFreqs.length; ++i) {
maxFreqs[i] = Math.max(maxFreqs[i], otherMaxFreqs[i]);
} }
for (Impact entry : acc.otherFreqNormPairs) {
add(entry, otherFreqNormPairs);
}
assertConsistent();
} }
/** Get the set of competitive freq and norm pairs, ordered by increasing freq and norm. */ /** Get the set of competitive freq and norm pairs, ordered by increasing freq and norm. */
public SortedSet<Impact> getCompetitiveFreqNormPairs() { public Collection<Impact> getCompetitiveFreqNormPairs() {
if (dirty) { List<Impact> impacts = new ArrayList<>();
for (int i = 0; i < maxFreqs.length; ++i) { int maxFreqForLowerNorms = 0;
if (maxFreqs[i] > 0) { for (int i = 0; i < maxFreqs.length; ++i) {
add(new Impact(maxFreqs[i], (byte) i)); int maxFreq = maxFreqs[i];
maxFreqs[i] = 0; if (maxFreq > maxFreqForLowerNorms) {
} impacts.add(new Impact(maxFreq, (byte) i));
maxFreqForLowerNorms = maxFreq;
} }
dirty = false;
} }
return Collections.unmodifiableSortedSet(freqNormPairs);
if (otherFreqNormPairs.isEmpty()) {
// Common case: all norms are bytes
return impacts;
}
TreeSet<Impact> freqNormPairs = new TreeSet<>(this.otherFreqNormPairs);
for (Impact impact : impacts) {
add(impact, freqNormPairs);
}
return Collections.unmodifiableSet(freqNormPairs);
} }
private void add(Impact newEntry) { private void add(Impact newEntry, TreeSet<Impact> freqNormPairs) {
Impact next = freqNormPairs.ceiling(newEntry); Impact next = freqNormPairs.ceiling(newEntry);
if (next == null) { if (next == null) {
// nothing is more competitive // nothing is more competitive
@@ -122,6 +145,23 @@ public final class CompetitiveImpactAccumulator {
@Override @Override
public String toString() { public String toString() {
return getCompetitiveFreqNormPairs().toString(); return new ArrayList<>(getCompetitiveFreqNormPairs()).toString();
}
// Only called by assertions
private boolean assertConsistent() {
for (int freq : maxFreqs) {
assert freq >= 0;
}
int previousFreq = 0;
long previousNorm = 0;
for (Impact impact : otherFreqNormPairs) {
assert impact.norm < Byte.MIN_VALUE || impact.norm > Byte.MAX_VALUE;
assert previousFreq < impact.freq;
assert Long.compareUnsigned(previousNorm, impact.norm) < 0;
previousFreq = impact.freq;
previousNorm = impact.norm;
}
return true;
} }
} }

View File

@@ -19,8 +19,7 @@ package org.apache.lucene.codecs.lucene84;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Set; import java.util.Collection;
import java.util.SortedSet;
import org.apache.lucene.codecs.CompetitiveImpactAccumulator; import org.apache.lucene.codecs.CompetitiveImpactAccumulator;
import org.apache.lucene.codecs.MultiLevelSkipListWriter; import org.apache.lucene.codecs.MultiLevelSkipListWriter;
@@ -141,7 +140,7 @@ final class Lucene84SkipWriter extends MultiLevelSkipListWriter {
// sets of competitive freq,norm pairs should be empty at this point // sets of competitive freq,norm pairs should be empty at this point
assert Arrays.stream(curCompetitiveFreqNorms) assert Arrays.stream(curCompetitiveFreqNorms)
.map(CompetitiveImpactAccumulator::getCompetitiveFreqNormPairs) .map(CompetitiveImpactAccumulator::getCompetitiveFreqNormPairs)
.mapToInt(Set::size) .mapToInt(Collection::size)
.sum() == 0; .sum() == 0;
initialized = true; initialized = true;
} }
@@ -205,7 +204,7 @@ final class Lucene84SkipWriter extends MultiLevelSkipListWriter {
} }
static void writeImpacts(CompetitiveImpactAccumulator acc, DataOutput out) throws IOException { static void writeImpacts(CompetitiveImpactAccumulator acc, DataOutput out) throws IOException {
SortedSet<Impact> impacts = acc.getCompetitiveFreqNormPairs(); Collection<Impact> impacts = acc.getCompetitiveFreqNormPairs();
Impact previous = new Impact(0, 0); Impact previous = new Impact(0, 0);
for (Impact impact : impacts) { for (Impact impact : impacts) {
assert impact.freq > previous.freq; assert impact.freq > previous.freq;

View File

@@ -17,8 +17,10 @@
package org.apache.lucene.codecs; package org.apache.lucene.codecs;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.Comparator;
import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.TreeSet;
import org.apache.lucene.index.Impact; import org.apache.lucene.index.Impact;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
@@ -27,59 +29,59 @@ public class TestCompetitiveFreqNormAccumulator extends LuceneTestCase {
public void testBasics() { public void testBasics() {
CompetitiveImpactAccumulator acc = new CompetitiveImpactAccumulator(); CompetitiveImpactAccumulator acc = new CompetitiveImpactAccumulator();
Set<Impact> expected = new HashSet<>(); Set<Impact> expected = new TreeSet<>(Comparator.comparingInt(i -> i.freq));
acc.add(3, 5); acc.add(3, 5);
expected.add(new Impact(3, 5)); expected.add(new Impact(3, 5));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(6, 11); acc.add(6, 11);
expected.add(new Impact(6, 11)); expected.add(new Impact(6, 11));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(10, 13); acc.add(10, 13);
expected.add(new Impact(10, 13)); expected.add(new Impact(10, 13));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(1, 2); acc.add(1, 2);
expected.add(new Impact(1, 2)); expected.add(new Impact(1, 2));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(7, 9); acc.add(7, 9);
expected.remove(new Impact(6, 11)); expected.remove(new Impact(6, 11));
expected.add(new Impact(7, 9)); expected.add(new Impact(7, 9));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(8, 2); acc.add(8, 2);
expected.clear(); expected.clear();
expected.add(new Impact(10, 13)); expected.add(new Impact(10, 13));
expected.add(new Impact(8, 2)); expected.add(new Impact(8, 2));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
} }
public void testExtremeNorms() { public void testExtremeNorms() {
CompetitiveImpactAccumulator acc = new CompetitiveImpactAccumulator(); CompetitiveImpactAccumulator acc = new CompetitiveImpactAccumulator();
Set<Impact> expected = new HashSet<>(); Set<Impact> expected = new TreeSet<>(Comparator.comparingInt(i -> i.freq));
acc.add(3, 5); acc.add(3, 5);
expected.add(new Impact(3, 5)); expected.add(new Impact(3, 5));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(10, 10000); acc.add(10, 10000);
expected.add(new Impact(10, 10000)); expected.add(new Impact(10, 10000));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(5, 200); acc.add(5, 200);
expected.add(new Impact(5, 200)); expected.add(new Impact(5, 200));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(20, -100); acc.add(20, -100);
expected.add(new Impact(20, -100)); expected.add(new Impact(20, -100));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
acc.add(30, -3); acc.add(30, -3);
expected.add(new Impact(30, -3)); expected.add(new Impact(30, -3));
assertEquals(expected, acc.getCompetitiveFreqNormPairs()); assertEquals(List.copyOf(expected), List.copyOf(acc.getCompetitiveFreqNormPairs()));
} }
public void testOmitFreqs() { public void testOmitFreqs() {
@@ -89,7 +91,7 @@ public class TestCompetitiveFreqNormAccumulator extends LuceneTestCase {
acc.add(1, 7); acc.add(1, 7);
acc.add(1, 4); acc.add(1, 4);
assertEquals(Collections.singleton(new Impact(1, 4)), acc.getCompetitiveFreqNormPairs()); assertEquals(Collections.singletonList(new Impact(1, 4)), List.copyOf(acc.getCompetitiveFreqNormPairs()));
} }
public void testOmitNorms() { public void testOmitNorms() {
@@ -99,6 +101,6 @@ public class TestCompetitiveFreqNormAccumulator extends LuceneTestCase {
acc.add(7, 1); acc.add(7, 1);
acc.add(4, 1); acc.add(4, 1);
assertEquals(Collections.singleton(new Impact(7, 1)), acc.getCompetitiveFreqNormPairs()); assertEquals(Collections.singletonList(new Impact(7, 1)), List.copyOf(acc.getCompetitiveFreqNormPairs()));
} }
} }