From ab2fec1642b5ebac66cd5a564552db29d139b54a Mon Sep 17 00:00:00 2001 From: Alan Woodward Date: Mon, 18 Jun 2018 09:29:58 +0100 Subject: [PATCH] LUCENE-8237: Correct handling of position increments in sub-tokenstreams --- .../miscellaneous/ConditionalTokenFilter.java | 60 +++++++++++-------- .../analysis/core/TestRandomChains.java | 20 ++++--- .../TestConditionalTokenFilter.java | 59 +++++++++++++++--- .../lucene/analysis/FilteringTokenFilter.java | 2 - .../analysis/ValidatingTokenFilter.java | 9 ++- 5 files changed, 103 insertions(+), 47 deletions(-) diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java index e41ce8268aa..b3ef2ab7041 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java @@ -43,24 +43,29 @@ public abstract class ConditionalTokenFilter extends TokenFilter { private final class OneTimeWrapper extends TokenStream { private final OffsetAttribute offsetAtt; + private final PositionIncrementAttribute posIncAtt; public OneTimeWrapper(AttributeSource attributeSource) { super(attributeSource); this.offsetAtt = attributeSource.addAttribute(OffsetAttribute.class); + this.posIncAtt = attributeSource.addAttribute(PositionIncrementAttribute.class); } @Override public boolean incrementToken() throws IOException { if (state == TokenState.PREBUFFERING) { + if (posIncAtt.getPositionIncrement() == 0) { + adjustPosition = true; + posIncAtt.setPositionIncrement(1); + } state = TokenState.DELEGATING; return true; } assert state == TokenState.DELEGATING; - boolean more = input.incrementToken(); - if (more && shouldFilter()) { - return true; - } - if (more) { + if (input.incrementToken()) { + if (shouldFilter()) { + return true; + } endOffset = offsetAtt.endOffset(); bufferedState = captureState(); } @@ -96,10 +101,11 @@ public abstract class ConditionalTokenFilter extends TokenFilter { private boolean lastTokenFiltered; private State bufferedState = null; private boolean exhausted; + private boolean adjustPosition; private State endState = null; private int endOffset; - private PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class); + private final PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class); /** * Create a new ConditionalTokenFilter @@ -124,6 +130,7 @@ public abstract class ConditionalTokenFilter extends TokenFilter { this.lastTokenFiltered = false; this.bufferedState = null; this.exhausted = false; + this.adjustPosition = false; this.endOffset = -1; this.endState = null; } @@ -168,16 +175,6 @@ public abstract class ConditionalTokenFilter extends TokenFilter { return false; } if (shouldFilter()) { - // we're chopping the underlying Tokenstream up into fragments, and presenting - // only those parts of it that pass the filter to the delegate, so the delegate is - // in effect seeing multiple tokenstream snippets. 
Tokenstreams can't have an initial - // position increment of 0, so if the snippet starts on a stacked token we need to - // offset it here and then correct the increment back again after delegation - boolean adjustPosition = false; - if (posIncAtt.getPositionIncrement() == 0) { - posIncAtt.setPositionIncrement(1); - adjustPosition = true; - } lastTokenFiltered = true; state = TokenState.PREBUFFERING; // we determine that the delegate has emitted all the tokens it can at the current @@ -192,20 +189,14 @@ public abstract class ConditionalTokenFilter extends TokenFilter { int posInc = posIncAtt.getPositionIncrement(); posIncAtt.setPositionIncrement(posInc - 1); } + adjustPosition = false; } else { lastTokenFiltered = false; state = TokenState.READING; - if (bufferedState != null) { - delegate.end(); - int posInc = posIncAtt.getPositionIncrement(); - restoreState(bufferedState); - posIncAtt.setPositionIncrement(posIncAtt.getPositionIncrement() + posInc); - bufferedState = null; - return true; - } + return endDelegating(); } - return more || bufferedState != null; + return true; } lastTokenFiltered = false; return true; @@ -216,8 +207,27 @@ public abstract class ConditionalTokenFilter extends TokenFilter { } // no more cached tokens state = TokenState.READING; + return endDelegating(); } } } + private boolean endDelegating() throws IOException { + if (bufferedState == null) { + assert exhausted == true; + return false; + } + delegate.end(); + int posInc = posIncAtt.getPositionIncrement(); + restoreState(bufferedState); + // System.out.println("Buffered posInc: " + posIncAtt.getPositionIncrement() + " Delegated posInc: " + posInc); + posIncAtt.setPositionIncrement(posIncAtt.getPositionIncrement() + posInc); + if (adjustPosition) { + posIncAtt.setPositionIncrement(posIncAtt.getPositionIncrement() - 1); + adjustPosition = false; + } + bufferedState = null; + return true; + } + } diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java index f3a02697f0a..8ab57d7dacf 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java @@ -97,7 +97,6 @@ import org.apache.lucene.store.RAMDirectory; import org.apache.lucene.util.AttributeFactory; import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.CharsRef; -import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.Rethrow; import org.apache.lucene.util.TestUtil; import org.apache.lucene.util.Version; @@ -130,6 +129,9 @@ public class TestRandomChains extends BaseTokenStreamTestCase { // expose inconsistent offsets // https://issues.apache.org/jira/browse/LUCENE-4170 avoidConditionals.add(ShingleFilter.class); + // FlattenGraphFilter changes the output graph entirely, so wrapping it in a condition + // can break position lengths + avoidConditionals.add(FlattenGraphFilter.class); } private static final Map,Predicate> brokenConstructors = new HashMap<>(); @@ -626,7 +628,7 @@ public class TestRandomChains extends BaseTokenStreamTestCase { return sb.toString(); } - private T createComponent(Constructor ctor, Object[] args, StringBuilder descr) { + private T createComponent(Constructor ctor, Object[] args, StringBuilder descr, boolean isConditional) { try { final T instance = ctor.newInstance(args); /* @@ -635,6 +637,9 @@ public class TestRandomChains extends 
BaseTokenStreamTestCase { } */ descr.append("\n "); + if (isConditional) { + descr.append("Conditional:"); + } descr.append(ctor.getDeclaringClass().getName()); String params = Arrays.deepToString(args); params = params.substring(1, params.length()-1); @@ -673,7 +678,7 @@ public class TestRandomChains extends BaseTokenStreamTestCase { if (broken(ctor, args)) { continue; } - spec.tokenizer = createComponent(ctor, args, descr); + spec.tokenizer = createComponent(ctor, args, descr, false); if (spec.tokenizer != null) { spec.toString = descr.toString(); } @@ -693,7 +698,7 @@ public class TestRandomChains extends BaseTokenStreamTestCase { if (broken(ctor, args)) { continue; } - reader = createComponent(ctor, args, descr); + reader = createComponent(ctor, args, descr, false); if (reader != null) { spec.reader = reader; break; @@ -725,8 +730,7 @@ public class TestRandomChains extends BaseTokenStreamTestCase { if (broken(ctor, args)) { return in; } - descr.append("ConditionalTokenFilter: "); - TokenStream ts = createComponent(ctor, args, descr); + TokenStream ts = createComponent(ctor, args, descr, true); if (ts == null) { return in; } @@ -752,7 +756,7 @@ public class TestRandomChains extends BaseTokenStreamTestCase { if (broken(ctor, args)) { continue; } - final TokenFilter flt = createComponent(ctor, args, descr); + final TokenFilter flt = createComponent(ctor, args, descr, false); if (flt != null) { spec.stream = flt; break; @@ -849,7 +853,6 @@ public class TestRandomChains extends BaseTokenStreamTestCase { String toString; } - @LuceneTestCase.BadApple(bugUrl="https://issues.apache.org/jira/browse/SOLR-12028") // 12-Jun-2018 public void testRandomChains() throws Throwable { int numIterations = TEST_NIGHTLY ? atLeast(20) : 3; Random random = random(); @@ -878,7 +881,6 @@ public class TestRandomChains extends BaseTokenStreamTestCase { } // we might regret this decision... - @LuceneTestCase.BadApple(bugUrl="https://issues.apache.org/jira/browse/SOLR-12028") // 12-Jun-2018 public void testRandomChainsWithLargeStrings() throws Throwable { int numIterations = TEST_NIGHTLY ? 
atLeast(20) : 3; Random random = random(); diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java index e0bbac44794..00a2df19dfc 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java @@ -43,6 +43,7 @@ import org.apache.lucene.analysis.ngram.NGramTokenizer; import org.apache.lucene.analysis.shingle.FixedShingleFilter; import org.apache.lucene.analysis.shingle.ShingleFilter; import org.apache.lucene.analysis.standard.ClassicTokenizer; +import org.apache.lucene.analysis.standard.StandardTokenizer; import org.apache.lucene.analysis.synonym.SolrSynonymParser; import org.apache.lucene.analysis.synonym.SynonymGraphFilter; import org.apache.lucene.analysis.synonym.SynonymMap; @@ -260,6 +261,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase { protected TokenStreamComponents createComponents(String fieldName) { Tokenizer source = new ClassicTokenizer(); TokenStream sink = new ProtectedTermFilter(protectedTerms, source, in -> new ShingleFilter(in, 2)); + sink = new ValidatingTokenFilter(sink, "1"); return new TokenStreamComponents(source, sink); } }; @@ -268,14 +270,53 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase { try (TokenStream ts = analyzer.tokenStream("", input)) { assertTokenStreamContents(ts, new String[]{ - "one", "one two", - "two", - "three", - "four" - }); + "one", "one two", "two", "three", "four" + }, new int[]{ + 0, 0, 4, 8, 14 + }, new int[]{ + 3, 7, 7, 13, 18 + }, new int[]{ + 1, 0, 1, 1, 1 + }, new int[]{ + 1, 2, 1, 1, 1 + }, 18); } } + public void testFilteringWithReadahead() throws IOException { + + CharArraySet protectedTerms = new CharArraySet(2, true); + protectedTerms.add("two"); + protectedTerms.add("two three"); + + Analyzer analyzer = new Analyzer() { + @Override + protected TokenStreamComponents createComponents(String fieldName) { + Tokenizer source = new StandardTokenizer(); + TokenStream sink = new ShingleFilter(source, 3); + sink = new ProtectedTermFilter(protectedTerms, sink, in -> new TypeTokenFilter(in, Collections.singleton("ALL"), true)); + return new TokenStreamComponents(source, sink); + } + }; + + String input = "one two three four"; + + try (TokenStream ts = analyzer.tokenStream("", input)) { + assertTokenStreamContents(ts, new String[]{ + "two", "two three" + }, new int[]{ + 4, 4 + }, new int[]{ + 7, 13 + }, new int[]{ + 2, 0 + }, new int[]{ + 1, 2 + }, 18); + } + + } + public void testMultipleConditionalFilters() throws IOException { TokenStream stream = whitespaceMockTokenizer("Alice Bob Clara David"); TokenStream t = new SkipMatchingFilter(stream, in -> { @@ -311,7 +352,8 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase { @Override protected TokenStreamComponents createComponents(String fieldName) { Tokenizer source = new NGramTokenizer(); - TokenStream sink = new KeywordRepeatFilter(source); + TokenStream sink = new ValidatingTokenFilter(new KeywordRepeatFilter(source), "stage 0"); + sink = new ValidatingTokenFilter(sink, "stage 1"); sink = new RandomSkippingFilter(sink, seed, in -> new TypeTokenFilter(in, Collections.singleton("word"))); sink = new ValidatingTokenFilter(sink, "last stage"); return new TokenStreamComponents(source, 
sink); @@ -361,11 +403,12 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase { @Override public boolean incrementToken() throws IOException { - if (reset) { + boolean more = input.incrementToken(); + if (more && reset) { assertEquals(1, posIncAtt.getPositionIncrement()); } reset = false; - return input.incrementToken(); + return more; } } diff --git a/lucene/core/src/java/org/apache/lucene/analysis/FilteringTokenFilter.java b/lucene/core/src/java/org/apache/lucene/analysis/FilteringTokenFilter.java index cecad101d8f..e942224056e 100644 --- a/lucene/core/src/java/org/apache/lucene/analysis/FilteringTokenFilter.java +++ b/lucene/core/src/java/org/apache/lucene/analysis/FilteringTokenFilter.java @@ -19,8 +19,6 @@ package org.apache.lucene.analysis; import java.io.IOException; -import org.apache.lucene.analysis.TokenFilter; -import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; /** diff --git a/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java b/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java index 603ef200825..b29da708189 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java +++ b/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java @@ -62,6 +62,9 @@ public final class ValidatingTokenFilter extends TokenFilter { @Override public boolean incrementToken() throws IOException { + + // System.out.println(name + ": incrementToken()"); + if (!input.incrementToken()) { return false; } @@ -69,15 +72,15 @@ public final class ValidatingTokenFilter extends TokenFilter { int startOffset = 0; int endOffset = 0; int posLen = 0; + + // System.out.println(name + ": " + this); if (posIncAtt != null) { pos += posIncAtt.getPositionIncrement(); if (pos == -1) { - throw new IllegalStateException("first posInc must be > 0"); + throw new IllegalStateException(name + ": first posInc must be > 0"); } } - - // System.out.println(" got token=" + termAtt + " pos=" + pos); if (offsetAtt != null) { startOffset = offsetAtt.startOffset();
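
Note (illustrative sketch, not part of the patch): the chain below mirrors the ProtectedTermFilter/ShingleFilter setup used in TestConditionalTokenFilter above, with a plain consumer printing position increments. The class name, field name, and sample text are arbitrary and chosen only for illustration; the constructor signatures follow the test code in this patch.

import java.io.IOException;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.miscellaneous.ProtectedTermFilter;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;

public class ConditionalShingleExample {
  public static void main(String[] args) throws IOException {
    CharArraySet protectedTerms = new CharArraySet(2, true);
    protectedTerms.add("one");
    protectedTerms.add("two");

    Analyzer analyzer = new Analyzer() {
      @Override
      protected TokenStreamComponents createComponents(String fieldName) {
        Tokenizer source = new WhitespaceTokenizer();
        // The wrapped ShingleFilter only sees the protected terms; all other tokens
        // bypass it, so the inner filter consumes fragments of the underlying stream.
        TokenStream sink = new ProtectedTermFilter(protectedTerms, source, in -> new ShingleFilter(in, 2));
        return new TokenStreamComponents(source, sink);
      }
    };

    try (TokenStream ts = analyzer.tokenStream("f", "one two three four")) {
      CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
      PositionIncrementAttribute posIncAtt = ts.addAttribute(PositionIncrementAttribute.class);
      ts.reset();
      while (ts.incrementToken()) {
        // With this patch, the first token of each delegated fragment keeps a
        // position increment > 0, while stacked shingles (e.g. "one two" over "one")
        // remain stacked with a position increment of 0.
        System.out.println(termAtt + " posInc=" + posIncAtt.getPositionIncrement());
      }
      ts.end();
    }
  }
}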