LUCENE-8916: GraphTokenStreamFiniteStrings preserves all attributes

This commit is contained in:
Alan Woodward 2019-07-19 11:19:21 +01:00
parent 24b94b8dac
commit 1ccef96767
4 changed files with 66 additions and 51 deletions

View File

@ -68,6 +68,9 @@ Improvements
* LUCENE-8918: PhraseQuery throws exceptions at construction time if it is passed
null arguments. (Alan Woodward)
* LUCENE-8916: GraphTokenStreamFiniteStrings preserves all Token attributes
through its finite strings TokenStreams (Alan Woodward)
Other
* LUCENE-8778 LUCENE-8911: Define analyzer SPI names as static final fields and document the names in Javadocs.

View File

@ -18,11 +18,13 @@
package org.apache.lucene.analysis.shingle;
import java.io.IOException;
import java.util.Iterator;
import org.apache.lucene.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.analysis.CannedTokenStream;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.util.graph.GraphTokenStreamFiniteStrings;
public class FixedShingleFilterTest extends BaseTokenStreamTestCase {
@ -227,4 +229,19 @@ public class FixedShingleFilterTest extends BaseTokenStreamTestCase {
assertEquals("Shingle size must be between 2 and 4, got 5", e2.getMessage());
}
public void testWithGraphInput() throws IOException {
TokenStream ts = new CannedTokenStream(
new Token("fuz", 0, 3),
new Token("foo", 1, 4, 6, 2),
new Token("bar", 0, 4, 6),
new Token("baz", 1, 4, 6)
);
GraphTokenStreamFiniteStrings graph = new GraphTokenStreamFiniteStrings(ts);
Iterator<TokenStream> it = graph.getFiniteStrings();
assertTokenStreamContents(new FixedShingleFilter(it.next(), 2), new String[]{ "fuz foo"});
assertTokenStreamContents(new FixedShingleFilter(it.next(), 2), new String[]{ "fuz bar", "bar baz"});
}
}

View File

@ -22,18 +22,16 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.BytesTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.index.Term;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.FiniteStringsIterator;
@ -48,19 +46,18 @@ import static org.apache.lucene.util.automaton.Operations.DEFAULT_MAX_DETERMINIZ
* This class also provides helpers to explore the different paths of the {@link Automaton}.
*/
public final class GraphTokenStreamFiniteStrings {
private final Map<Integer, BytesRef> idToTerm = new HashMap<>();
private final Map<Integer, Integer> idToInc = new HashMap<>();
private AttributeSource[] tokens = new AttributeSource[4];
private final Automaton det;
private final Transition transition = new Transition();
private class FiniteStringsTokenStream extends TokenStream {
private final BytesTermAttribute termAtt = addAttribute(BytesTermAttribute.class);
private final PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class);
private final IntsRef ids;
private final int end;
private int offset;
FiniteStringsTokenStream(final IntsRef ids) {
super(tokens[0].cloneAttributes());
assert ids != null;
this.ids = ids;
this.offset = ids.offset;
@ -72,13 +69,7 @@ public final class GraphTokenStreamFiniteStrings {
if (offset < end) {
clearAttributes();
int id = ids.ints[offset];
termAtt.setBytesRef(idToTerm.get(id));
int incr = 1;
if (idToInc.containsKey(id)) {
incr = idToInc.get(id);
}
posIncAtt.setPositionIncrement(incr);
tokens[id].copyTo(this);
offset++;
return true;
}
@ -111,20 +102,26 @@ public final class GraphTokenStreamFiniteStrings {
return false;
}
/**
* Returns the list of tokens that start at the provided state
*/
public List<AttributeSource> getTerms(int state) {
int numT = det.initTransition(state, transition);
List<AttributeSource> tokens = new ArrayList<> ();
for (int i = 0; i < numT; i++) {
det.getNextTransition(transition);
tokens.addAll(Arrays.asList(this.tokens).subList(transition.min, transition.max + 1));
}
return tokens;
}
/**
* Returns the list of terms that start at the provided state
*/
public Term[] getTerms(String field, int state) {
int numT = det.initTransition(state, transition);
List<Term> terms = new ArrayList<> ();
for (int i = 0; i < numT; i++) {
det.getNextTransition(transition);
for (int id = transition.min; id <= transition.max; id++) {
Term term = new Term(field, idToTerm.get(id));
terms.add(term);
}
}
return terms.toArray(new Term[terms.size()]);
return getTerms(state).stream()
.map(s -> new Term(field, s.addAttribute(TermToBytesRefAttribute.class).getBytesRef()))
.toArray(Term[]::new);
}
/**
@ -138,9 +135,9 @@ public final class GraphTokenStreamFiniteStrings {
/**
* Get all finite strings that start at {@code startState} and end at {@code endState}.
*/
public Iterator<TokenStream> getFiniteStrings(int startState, int endState) throws IOException {
public Iterator<TokenStream> getFiniteStrings(int startState, int endState) {
final FiniteStringsIterator it = new FiniteStringsIterator(det, startState, endState);
return new Iterator<TokenStream> () {
return new Iterator<> () {
IntsRef current;
boolean finished = false;
@ -202,7 +199,7 @@ public final class GraphTokenStreamFiniteStrings {
*/
private Automaton build(final TokenStream in) throws IOException {
Automaton.Builder builder = new Automaton.Builder();
final TermToBytesRefAttribute termBytesAtt = in.addAttribute(TermToBytesRefAttribute.class);
final PositionIncrementAttribute posIncAtt = in.addAttribute(PositionIncrementAttribute.class);
final PositionLengthAttribute posLengthAtt = in.addAttribute(PositionLengthAttribute.class);
@ -211,6 +208,7 @@ public final class GraphTokenStreamFiniteStrings {
int pos = -1;
int prevIncr = 1;
int state = -1;
int id = -1;
int gap = 0;
while (in.incrementToken()) {
int currentIncr = posIncAtt.getPositionIncrement();
@ -233,12 +231,23 @@ public final class GraphTokenStreamFiniteStrings {
state = builder.createState();
}
BytesRef term = termBytesAtt.getBytesRef();
int id = getTermID(currentIncr, prevIncr, term);
//System.out.println("Adding transition: " + term.utf8ToString() + "@" + pos + "->" + endPos);
id++;
if (tokens.length < id + 1) {
tokens = ArrayUtil.grow(tokens, id + 1);
}
tokens[id] = in.cloneAttributes();
builder.addTransition(pos, endPos, id);
pos += gap;
// we always produce linear token graphs from getFiniteStrings(), so we need to adjust
// posLength and posIncrement accordingly
tokens[id].addAttribute(PositionLengthAttribute.class).setPositionLength(1);
if (currentIncr == 0) {
// stacked token should have the same increment as original token at this position
tokens[id].addAttribute(PositionIncrementAttribute.class).setPositionIncrement(prevIncr);
}
// only save last increment on non-zero increment in case we have multiple stacked tokens
if (currentIncr > 0) {
prevIncr = currentIncr;
@ -252,23 +261,6 @@ public final class GraphTokenStreamFiniteStrings {
return builder.finish();
}
/**
* Gets an integer id for a given term and saves the position increment if needed.
*/
private int getTermID(int incr, int prevIncr, BytesRef term) {
assert term != null;
boolean isStackedGap = incr == 0 && prevIncr > 1;
int id = idToTerm.size();
idToTerm.put(id, BytesRef.deepCopyOf(term));
// stacked token should have the same increment as original token at this position
if (isStackedGap) {
idToInc.put(id, prevIncr);
} else if (incr > 1) {
idToInc.put(id, incr);
}
return id;
}
private static void articulationPointsRecurse(Automaton a, int state, int d, int[] depth, int[] low, int[] parent,
BitSet visited, List<Integer> points) {
visited.set(state);

View File

@ -21,8 +21,9 @@ import java.util.Iterator;
import org.apache.lucene.analysis.CannedTokenStream;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.BytesTermAttribute;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.index.Term;
import org.apache.lucene.util.LuceneTestCase;
@ -44,14 +45,16 @@ public class TestGraphTokenStreamFiniteStrings extends LuceneTestCase {
assertNotNull(terms);
assertNotNull(increments);
assertEquals(terms.length, increments.length);
BytesTermAttribute termAtt = ts.getAttribute(BytesTermAttribute.class);
CharTermAttribute termAtt = ts.getAttribute(CharTermAttribute.class);
PositionIncrementAttribute incrAtt = ts.getAttribute(PositionIncrementAttribute.class);
PositionLengthAttribute lenAtt = ts.getAttribute(PositionLengthAttribute.class);
int offset = 0;
while (ts.incrementToken()) {
// verify term and increment
assert offset < terms.length;
assertEquals(terms[offset], termAtt.getBytesRef().utf8ToString());
assertEquals(terms[offset], termAtt.toString());
assertEquals(increments[offset], incrAtt.getPositionIncrement());
assertEquals(1, lenAtt.getPositionLength()); // we always output linear token streams
offset++;
}