diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5a3eb836581..1c80a5e8680 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -68,6 +68,9 @@ Improvements * LUCENE-8918: PhraseQuery throws exceptions at construction time if it is passed null arguments. (Alan Woodward) +* LUCENE-8916: GraphTokenStreamFiniteStrings preserves all Token attributes + through its finite strings TokenStreams (Alan Woodward) + Other * LUCENE-8778 LUCENE-8911: Define analyzer SPI names as static final fields and document the names in Javadocs. diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/shingle/FixedShingleFilterTest.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/shingle/FixedShingleFilterTest.java index 3c7ffe07521..978ade5d535 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/shingle/FixedShingleFilterTest.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/shingle/FixedShingleFilterTest.java @@ -18,11 +18,13 @@ package org.apache.lucene.analysis.shingle; import java.io.IOException; +import java.util.Iterator; import org.apache.lucene.analysis.BaseTokenStreamTestCase; import org.apache.lucene.analysis.CannedTokenStream; import org.apache.lucene.analysis.Token; import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.util.graph.GraphTokenStreamFiniteStrings; public class FixedShingleFilterTest extends BaseTokenStreamTestCase { @@ -227,4 +229,19 @@ public class FixedShingleFilterTest extends BaseTokenStreamTestCase { assertEquals("Shingle size must be between 2 and 4, got 5", e2.getMessage()); } + public void testWithGraphInput() throws IOException { + + TokenStream ts = new CannedTokenStream( + new Token("fuz", 0, 3), + new Token("foo", 1, 4, 6, 2), + new Token("bar", 0, 4, 6), + new Token("baz", 1, 4, 6) + ); + GraphTokenStreamFiniteStrings graph = new GraphTokenStreamFiniteStrings(ts); + Iterator<TokenStream> it = graph.getFiniteStrings(); + assertTokenStreamContents(new 
FixedShingleFilter(it.next(), 2), new String[]{ "fuz foo"}); + assertTokenStreamContents(new FixedShingleFilter(it.next(), 2), new String[]{ "fuz bar", "bar baz"}); + + } + } diff --git a/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java b/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java index b6a99958e6a..b2b530d93af 100644 --- a/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java +++ b/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java @@ -22,18 +22,16 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; import java.util.List; -import java.util.Map; import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.tokenattributes.BytesTermAttribute; import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute; import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute; import org.apache.lucene.index.Term; -import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.FiniteStringsIterator; @@ -48,19 +46,18 @@ import static org.apache.lucene.util.automaton.Operations.DEFAULT_MAX_DETERMINIZ * This class also provides helpers to explore the different paths of the {@link Automaton}. 
*/ public final class GraphTokenStreamFiniteStrings { - private final Map<Integer, BytesRef> idToTerm = new HashMap<>(); - private final Map<Integer, Integer> idToInc = new HashMap<>(); + + private AttributeSource[] tokens = new AttributeSource[4]; private final Automaton det; private final Transition transition = new Transition(); private class FiniteStringsTokenStream extends TokenStream { - private final BytesTermAttribute termAtt = addAttribute(BytesTermAttribute.class); - private final PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class); private final IntsRef ids; private final int end; private int offset; FiniteStringsTokenStream(final IntsRef ids) { + super(tokens[0].cloneAttributes()); assert ids != null; this.ids = ids; this.offset = ids.offset; @@ -72,13 +69,7 @@ public final class GraphTokenStreamFiniteStrings { if (offset < end) { clearAttributes(); int id = ids.ints[offset]; - termAtt.setBytesRef(idToTerm.get(id)); - - int incr = 1; - if (idToInc.containsKey(id)) { - incr = idToInc.get(id); - } - posIncAtt.setPositionIncrement(incr); + tokens[id].copyTo(this); offset++; return true; } @@ -111,20 +102,26 @@ public final class GraphTokenStreamFiniteStrings { return false; } + /** + * Returns the list of tokens that start at the provided state + */ + public List<AttributeSource> getTerms(int state) { + int numT = det.initTransition(state, transition); + List<AttributeSource> tokens = new ArrayList<> (); + for (int i = 0; i < numT; i++) { + det.getNextTransition(transition); + tokens.addAll(Arrays.asList(this.tokens).subList(transition.min, transition.max + 1)); + } + return tokens; + } + /** * Returns the list of terms that start at the provided state */ public Term[] getTerms(String field, int state) { - int numT = det.initTransition(state, transition); - List<Term> terms = new ArrayList<> (); - for (int i = 0; i < numT; i++) { - det.getNextTransition(transition); - for (int id = transition.min; id <= transition.max; id++) { - Term term = new Term(field, idToTerm.get(id)); - terms.add(term); - } 
- } - return terms.toArray(new Term[terms.size()]); + return getTerms(state).stream() + .map(s -> new Term(field, s.addAttribute(TermToBytesRefAttribute.class).getBytesRef())) + .toArray(Term[]::new); } /** @@ -138,9 +135,9 @@ public final class GraphTokenStreamFiniteStrings { /** * Get all finite strings that start at {@code startState} and end at {@code endState}. */ - public Iterator<TokenStream> getFiniteStrings(int startState, int endState) throws IOException { + public Iterator<TokenStream> getFiniteStrings(int startState, int endState) { final FiniteStringsIterator it = new FiniteStringsIterator(det, startState, endState); - return new Iterator<TokenStream> () { + return new Iterator<> () { IntsRef current; boolean finished = false; @@ -202,7 +199,7 @@ public final class GraphTokenStreamFiniteStrings { */ private Automaton build(final TokenStream in) throws IOException { Automaton.Builder builder = new Automaton.Builder(); - final TermToBytesRefAttribute termBytesAtt = in.addAttribute(TermToBytesRefAttribute.class); + final PositionIncrementAttribute posIncAtt = in.addAttribute(PositionIncrementAttribute.class); final PositionLengthAttribute posLengthAtt = in.addAttribute(PositionLengthAttribute.class); @@ -211,6 +208,7 @@ public final class GraphTokenStreamFiniteStrings { int pos = -1; int prevIncr = 1; int state = -1; + int id = -1; int gap = 0; while (in.incrementToken()) { int currentIncr = posIncAtt.getPositionIncrement(); @@ -233,12 +231,23 @@ public final class GraphTokenStreamFiniteStrings { state = builder.createState(); } - BytesRef term = termBytesAtt.getBytesRef(); - int id = getTermID(currentIncr, prevIncr, term); - //System.out.println("Adding transition: " + term.utf8ToString() + "@" + pos + "->" + endPos); + id++; + if (tokens.length < id + 1) { + tokens = ArrayUtil.grow(tokens, id + 1); + } + + tokens[id] = in.cloneAttributes(); builder.addTransition(pos, endPos, id); pos += gap; + // we always produce linear token graphs from getFiniteStrings(), so we need to adjust + // 
posLength and posIncrement accordingly + tokens[id].addAttribute(PositionLengthAttribute.class).setPositionLength(1); + if (currentIncr == 0) { + // stacked token should have the same increment as original token at this position + tokens[id].addAttribute(PositionIncrementAttribute.class).setPositionIncrement(prevIncr); + } + // only save last increment on non-zero increment in case we have multiple stacked tokens if (currentIncr > 0) { prevIncr = currentIncr; @@ -252,23 +261,6 @@ public final class GraphTokenStreamFiniteStrings { return builder.finish(); } - /** - * Gets an integer id for a given term and saves the position increment if needed. - */ - private int getTermID(int incr, int prevIncr, BytesRef term) { - assert term != null; - boolean isStackedGap = incr == 0 && prevIncr > 1; - int id = idToTerm.size(); - idToTerm.put(id, BytesRef.deepCopyOf(term)); - // stacked token should have the same increment as original token at this position - if (isStackedGap) { - idToInc.put(id, prevIncr); - } else if (incr > 1) { - idToInc.put(id, incr); - } - return id; - } - private static void articulationPointsRecurse(Automaton a, int state, int d, int[] depth, int[] low, int[] parent, BitSet visited, List points) { visited.set(state); diff --git a/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java b/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java index 1739fa0c7d6..d0bb996b565 100644 --- a/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java +++ b/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java @@ -21,8 +21,9 @@ import java.util.Iterator; import org.apache.lucene.analysis.CannedTokenStream; import org.apache.lucene.analysis.Token; import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.tokenattributes.BytesTermAttribute; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; 
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; +import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute; import org.apache.lucene.index.Term; import org.apache.lucene.util.LuceneTestCase; @@ -44,14 +45,16 @@ public class TestGraphTokenStreamFiniteStrings extends LuceneTestCase { assertNotNull(terms); assertNotNull(increments); assertEquals(terms.length, increments.length); - BytesTermAttribute termAtt = ts.getAttribute(BytesTermAttribute.class); + CharTermAttribute termAtt = ts.getAttribute(CharTermAttribute.class); PositionIncrementAttribute incrAtt = ts.getAttribute(PositionIncrementAttribute.class); + PositionLengthAttribute lenAtt = ts.getAttribute(PositionLengthAttribute.class); int offset = 0; while (ts.incrementToken()) { // verify term and increment assert offset < terms.length; - assertEquals(terms[offset], termAtt.getBytesRef().utf8ToString()); + assertEquals(terms[offset], termAtt.toString()); assertEquals(increments[offset], incrAtt.getPositionIncrement()); + assertEquals(1, lenAtt.getPositionLength()); // we always output linear token streams offset++; }