toposort use iterator to avoid stackoverflow (#12286)

Co-authored-by: tangdonghai <tangdonghai@meituan.com>
This commit is contained in:
tang donghai 2023-05-15 22:20:15 +08:00 committed by GitHub
parent 223e28ef16
commit 5d203f8337
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 49 deletions

View File

@ -126,6 +126,8 @@ Optimizations
* GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun)
* GITHUB#12286 Toposort use iterator to avoid stackoverflow. (Tang Donghai)
Bug Fixes
---------------------
(No changes)

View File

@ -1273,9 +1273,14 @@ public final class Operations {
}
/**
* Returns the topological sort of all states reachable from the initial state. Behavior is
* undefined if this automaton has cycles. CPU cost is O(numTransitions), and the implementation
* is recursive so an automaton matching long strings may exhaust the java stack.
* Returns the topological sort of all states reachable from the initial state. This method
* assumes that the automaton does not contain cycles, and will throw an IllegalArgumentException
* if a cycle is detected. The CPU cost is O(numTransitions), and the implementation is
* non-recursive, so it will not exhaust the java stack for automaton matching long strings. If
* there are dead states in the automaton, they will be removed from the returned array.
*
* @param a the Automaton to be sorted
* @return the topologically sorted array of state ids
*/
public static int[] topoSortStates(Automaton a) {
if (a.getNumStates() == 0) {
@ -1283,8 +1288,7 @@ public final class Operations {
}
int numStates = a.getNumStates();
int[] states = new int[numStates];
final BitSet visited = new BitSet(numStates);
int upto = topoSortStatesRecurse(a, visited, states, 0, 0, 0);
int upto = topoSortStates(a, states);
if (upto < states.length) {
// There were dead states
@ -1303,24 +1307,49 @@ public final class Operations {
return states;
}
// TODO: not great that this is recursive... in theory a
// large automata could exceed java's stack so the maximum level of recursion is bounded to 1000
private static int topoSortStatesRecurse(
Automaton a, BitSet visited, int[] states, int upto, int state, int level) {
if (level > MAX_RECURSION_LEVEL) {
throw new IllegalArgumentException("input automaton is too large: " + level);
}
/**
* Performs a topological sort on the states of the given Automaton.
*
* @param a The automaton whose states are to be topologically sorted.
* @param states An int array which stores the states.
* @return the number of states in the final sorted list.
* @throws IllegalArgumentException if the input automaton has a cycle.
*/
private static int topoSortStates(Automaton a, int[] states) {
BitSet onStack = new BitSet(a.getNumStates());
BitSet visited = new BitSet(a.getNumStates());
var stack = new ArrayDeque<Integer>();
stack.push(0); // Assuming that the initial state is 0.
int upto = 0;
Transition t = new Transition();
int count = a.initTransition(state, t);
for (int i = 0; i < count; i++) {
a.getNextTransition(t);
if (!visited.get(t.dest)) {
visited.set(t.dest);
upto = topoSortStatesRecurse(a, visited, states, upto, t.dest, level + 1);
while (!stack.isEmpty()) {
int state = stack.peek(); // Just peek, don't remove the state yet
int count = a.initTransition(state, t);
boolean pushed = false;
for (int i = 0; i < count; i++) {
a.getNextTransition(t);
if (!visited.get(t.dest)) {
visited.set(t.dest);
stack.push(t.dest); // Push the next unvisited state onto the stack
onStack.set(state);
pushed = true;
break; // Exit the loop, we'll continue from here in the next iteration
} else if (onStack.get(t.dest)) {
// If the state is on the current recursion stack, we have detected a cycle
throw new IllegalArgumentException("Input automaton has a cycle.");
}
}
// If we haven't pushed any new state onto the stack, we're done with this state
if (!pushed) {
onStack.clear(state); // remove the node from the current recursion stack
stack.pop();
states[upto] = state;
upto++;
}
}
states[upto] = state;
upto++;
return upto;
}
}

View File

@ -17,6 +17,7 @@
package org.apache.lucene.util.automaton;
import static org.apache.lucene.util.automaton.Operations.DEFAULT_DETERMINIZE_WORK_LIMIT;
import static org.apache.lucene.util.automaton.Operations.topoSortStates;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.util.ArrayList;
@ -69,6 +70,42 @@ public class TestOperations extends LuceneTestCase {
assertTrue(Operations.isEmpty(concat));
}
/**
* Test case for the topoSortStates method when the input Automaton contains a cycle. This test
* case constructs an Automaton with two disjoint sets of statesone without a cycle and one with
* a cycle. The topoSortStates method should detect the presence of a cycle and throw an
* IllegalArgumentException.
*/
public void testCycledAutomaton() {
Automaton a = generateRandomAutomaton(true);
IllegalArgumentException exc =
expectThrows(IllegalArgumentException.class, () -> topoSortStates(a));
assertTrue(exc.getMessage().contains("Input automaton has a cycle"));
}
public void testTopoSortStates() {
Automaton a = generateRandomAutomaton(false);
int[] sorted = topoSortStates(a);
int[] stateMap = new int[a.getNumStates()];
Arrays.fill(stateMap, -1);
int order = 0;
for (int state : sorted) {
assertEquals(-1, stateMap[state]);
stateMap[state] = (order++);
}
Transition transition = new Transition();
for (int state : sorted) {
int count = a.initTransition(state, transition);
for (int i = 0; i < count; i++) {
a.getNextTransition(transition);
// ensure dest's order is higher than current state
assertTrue(stateMap[transition.dest] > stateMap[state]);
}
}
}
/** Test optimization to concatenate() with empty String to an NFA */
public void testEmptySingletonNFAConcatenate() {
Automaton singleton = Automata.makeString("");
@ -136,19 +173,6 @@ public class TestOperations extends LuceneTestCase {
assertTrue(exc.getMessage().contains("input automaton is too large"));
}
public void testTopoSortEatsStack() {
char[] chars = new char[50000];
TestUtil.randomFixedLengthUnicodeString(random(), chars, 0, chars.length);
String bigString1 = new String(chars);
TestUtil.randomFixedLengthUnicodeString(random(), chars, 0, chars.length);
String bigString2 = new String(chars);
Automaton a =
Operations.union(Automata.makeString(bigString1), Automata.makeString(bigString2));
IllegalArgumentException exc =
expectThrows(IllegalArgumentException.class, () -> Operations.topoSortStates(a));
assertTrue(exc.getMessage().contains("input automaton is too large"));
}
/**
* Returns the set of all accepted strings.
*
@ -182,4 +206,52 @@ public class TestOperations extends LuceneTestCase {
return result;
}
/**
* This method creates a random Automaton by generating states at multiple levels. At each level,
* a random number of states are created, and transitions are added between the states of the
* current and the previous level randomly, If the 'hasCycle' parameter is true, a transition is
* added from the first state of the last level back to the initial state to create a cycle in the
* Automaton..
*
* @param hasCycle if true, the generated Automaton will have a cycle; if false, it won't have a
* cycle.
* @return a randomly generated Automaton instance.
*/
private Automaton generateRandomAutomaton(boolean hasCycle) {
Automaton a = new Automaton();
List<Integer> lastLevelStates = new ArrayList<>();
int initialState = a.createState();
int maxLevel = random().nextInt(4, 10);
lastLevelStates.add(initialState);
for (int level = 1; level < maxLevel; level++) {
int numStates = random().nextInt(3, 10);
List<Integer> nextLevelStates = new ArrayList<>();
for (int i = 0; i < numStates; i++) {
int nextState = a.createState();
nextLevelStates.add(nextState);
}
for (int lastState : lastLevelStates) {
for (int nextState : nextLevelStates) {
// if hasCycle is enabled, we will always add a transition, so we could make sure the
// generated Automaton has a cycle.
if (hasCycle || random().nextInt(7) >= 1) {
a.addTransition(lastState, nextState, random().nextInt(10));
}
}
}
lastLevelStates = nextLevelStates;
}
if (hasCycle) {
int lastState = lastLevelStates.get(0);
a.addTransition(lastState, initialState, random().nextInt(10));
}
a.finishState();
return a;
}
}

View File

@ -1325,22 +1325,6 @@ public class TestAnalyzingSuggester extends LuceneTestCase {
return asList;
}
// TODO: we need BaseSuggesterTestCase?
public void testTooLongSuggestion() throws Exception {
Analyzer a = new MockAnalyzer(random());
Directory tempDir = getDirectory();
AnalyzingSuggester suggester = new AnalyzingSuggester(tempDir, "suggest", a);
String bigString = TestUtil.randomSimpleString(random(), 30000, 30000);
IllegalArgumentException ex =
expectThrows(
IllegalArgumentException.class,
() -> {
suggester.build(new InputArrayIterator(new Input[] {new Input(bigString, 7)}));
});
assertTrue(ex.getMessage().contains("input automaton is too large"));
IOUtils.close(a, tempDir);
}
private Directory getDirectory() {
return newDirectory();
}