Reduce memory use of MinimizationOperations#minimize (#13511)

It is relatively easy to consume a massive amount of memory
for the minimize operation, with its lists of boxed Integer (even though these are mostly cached,
it's still more than 4b per instance to store them instead of plain storage) and neverending
duplicate+empty StateList instances.
The boxed integer situation we can fix and probably speedup by using the hppc primitive collections.
To fix the duplicate/empty StateList instances, we can use a constant. This requires some hacky forking
on the write path but that's about it.
This is partly motivated by ES users at times creating broken, very long prefix queries that can then eat up
GBs of heap. With this change, the examples I've been looking at become about 6x cheaper heap wise, making it
less likely that kind of mistakes impacts stability.
This commit is contained in:
Armin Braun 2024-06-25 16:15:30 +02:00 committed by GitHub
parent 3ae59a9809
commit 33a4c1d8ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 44 additions and 33 deletions

View File

@ -29,10 +29,11 @@
package org.apache.lucene.util.automaton;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.LinkedList;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.internal.hppc.IntCursor;
import org.apache.lucene.internal.hppc.IntHashSet;
/**
* Operations for minimizing automata.
@ -75,13 +76,9 @@ public final class MinimizationOperations {
final int[] sigma = a.getStartPoints();
final int sigmaLen = sigma.length, statesLen = a.getNumStates();
@SuppressWarnings({"rawtypes", "unchecked"})
final ArrayList<Integer>[][] reverse =
(ArrayList<Integer>[][]) new ArrayList[statesLen][sigmaLen];
@SuppressWarnings({"rawtypes", "unchecked"})
final HashSet<Integer>[] partition = (HashSet<Integer>[]) new HashSet[statesLen];
@SuppressWarnings({"rawtypes", "unchecked"})
final ArrayList<Integer>[] splitblock = (ArrayList<Integer>[]) new ArrayList[statesLen];
final IntArrayList[][] reverse = new IntArrayList[statesLen][sigmaLen];
final IntHashSet[] partition = new IntHashSet[statesLen];
final IntArrayList[] splitblock = new IntArrayList[statesLen];
final int[] block = new int[statesLen];
final StateList[][] active = new StateList[statesLen][sigmaLen];
final StateListNode[][] active2 = new StateListNode[statesLen][sigmaLen];
@ -91,10 +88,10 @@ public final class MinimizationOperations {
refine = new BitSet(statesLen),
refine2 = new BitSet(statesLen);
for (int q = 0; q < statesLen; q++) {
splitblock[q] = new ArrayList<>();
partition[q] = new HashSet<>();
splitblock[q] = new IntArrayList();
partition[q] = new IntHashSet();
for (int x = 0; x < sigmaLen; x++) {
active[q][x] = new StateList();
active[q][x] = StateList.EMPTY;
}
}
// find initial partition and reverse edges
@ -106,9 +103,9 @@ public final class MinimizationOperations {
transition.source = q;
transition.transitionUpto = -1;
for (int x = 0; x < sigmaLen; x++) {
final ArrayList<Integer>[] r = reverse[a.next(transition, sigma[x])];
final IntArrayList[] r = reverse[a.next(transition, sigma[x])];
if (r[x] == null) {
r[x] = new ArrayList<>();
r[x] = new IntArrayList();
}
r[x].add(q);
}
@ -116,9 +113,15 @@ public final class MinimizationOperations {
// initialize active sets
for (int j = 0; j <= 1; j++) {
for (int x = 0; x < sigmaLen; x++) {
for (int q : partition[j]) {
for (IntCursor qCursor : partition[j]) {
int q = qCursor.value;
if (reverse[q][x] != null) {
active2[q][x] = active[j][x].add(q);
StateList stateList = active[j][x];
if (stateList == StateList.EMPTY) {
stateList = new StateList();
active[j][x] = stateList;
}
active2[q][x] = stateList.add(q);
}
}
}
@ -143,9 +146,10 @@ public final class MinimizationOperations {
pending2.clear(x * statesLen + p);
// find states that need to be split off their blocks
for (StateListNode m = active[p][x].first; m != null; m = m.next) {
final ArrayList<Integer> r = reverse[m.q][x];
final IntArrayList r = reverse[m.q][x];
if (r != null) {
for (int i : r) {
for (IntCursor iCursor : r) {
final int i = iCursor.value;
if (!split.get(i)) {
split.set(i);
final int j = block[i];
@ -161,11 +165,12 @@ public final class MinimizationOperations {
// refine blocks
for (int j = refine.nextSetBit(0); j >= 0; j = refine.nextSetBit(j + 1)) {
final ArrayList<Integer> sb = splitblock[j];
final IntArrayList sb = splitblock[j];
if (sb.size() < partition[j].size()) {
final HashSet<Integer> b1 = partition[j];
final HashSet<Integer> b2 = partition[k];
for (int s : sb) {
final IntHashSet b1 = partition[j];
final IntHashSet b2 = partition[k];
for (IntCursor iCursor : sb) {
final int s = iCursor.value;
b1.remove(s);
b2.add(s);
block[s] = k;
@ -173,7 +178,12 @@ public final class MinimizationOperations {
final StateListNode sn = active2[s][c];
if (sn != null && sn.sl == active[j][c]) {
sn.remove();
active2[s][c] = active[k][c].add(s);
StateList stateList = active[k][c];
if (stateList == StateList.EMPTY) {
stateList = new StateList();
active[k][c] = stateList;
}
active2[s][c] = stateList.add(s);
}
}
}
@ -191,7 +201,8 @@ public final class MinimizationOperations {
k++;
}
refine2.clear(j);
for (int s : sb) {
for (IntCursor iCursor : sb) {
final int s = iCursor.value;
split.clear(s);
}
sb.clear();
@ -215,17 +226,11 @@ public final class MinimizationOperations {
for (int n = 0; n < k; n++) {
// System.out.println(" n=" + n);
boolean isInitial = false;
for (int q : partition[n]) {
if (q == 0) {
isInitial = true;
// System.out.println(" isInitial!");
break;
}
}
boolean isInitial = partition[n].contains(0);
int newState;
if (isInitial) {
// System.out.println(" isInitial!");
newState = 0;
} else {
newState = result.createState();
@ -233,7 +238,8 @@ public final class MinimizationOperations {
// System.out.println(" newState=" + newState);
for (int q : partition[n]) {
for (IntCursor qCursor : partition[n]) {
int q = qCursor.value;
stateMap[q] = newState;
// System.out.println(" q=" + q + " isAccept?=" + a.isAccept(q));
result.setAccept(newState, a.isAccept(q));
@ -268,11 +274,16 @@ public final class MinimizationOperations {
static final class StateList {
// Empty list that should never be mutated, used as a memory saving optimization instead of null
// so we don't need to branch the read path in #minimize
static final StateList EMPTY = new StateList();
int size;
StateListNode first, last;
StateListNode add(int q) {
assert this != EMPTY;
return new StateListNode(q, this);
}
}