diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java index 6f9ea244fd5..7de4fbde645 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java @@ -80,10 +80,10 @@ public abstract class ConditionalTokenFilter extends TokenFilter { public void end() throws IOException { // imitate Tokenizer.end() call - endAttributes, set final offset if (exhausted) { - if (endCalled == false) { + if (endState == null) { input.end(); + endState = captureState(); } - endCalled = true; endOffset = offsetAtt.endOffset(); } endAttributes(); @@ -96,7 +96,7 @@ public abstract class ConditionalTokenFilter extends TokenFilter { private boolean lastTokenFiltered; private State bufferedState = null; private boolean exhausted; - private boolean endCalled; + private State endState = null; private int endOffset; private PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class); @@ -125,18 +125,22 @@ public abstract class ConditionalTokenFilter extends TokenFilter { this.bufferedState = null; this.exhausted = false; this.endOffset = -1; - this.endCalled = false; + this.endState = null; } @Override public void end() throws IOException { - if (endCalled == false) { + if (endState == null) { super.end(); - endCalled = true; + endState = captureState(); + } + else { + restoreState(endState); } endOffset = getAttribute(OffsetAttribute.class).endOffset(); if (lastTokenFiltered) { this.delegate.end(); + endState = captureState(); } } diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java index fed7f68769b..511c725e6be 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java @@ -37,7 +37,10 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.Tokenizer; import org.apache.lucene.analysis.ValidatingTokenFilter; import org.apache.lucene.analysis.core.TypeTokenFilter; +import org.apache.lucene.analysis.de.GermanStemFilter; +import org.apache.lucene.analysis.in.IndicNormalizationFilter; import org.apache.lucene.analysis.ngram.NGramTokenizer; +import org.apache.lucene.analysis.shingle.FixedShingleFilter; import org.apache.lucene.analysis.shingle.ShingleFilter; import org.apache.lucene.analysis.standard.ClassicTokenizer; import org.apache.lucene.analysis.synonym.SolrSynonymParser; @@ -308,19 +311,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase { protected TokenStreamComponents createComponents(String fieldName) { Tokenizer source = new NGramTokenizer(); TokenStream sink = new KeywordRepeatFilter(source); - sink = new ConditionalTokenFilter(sink, in -> new TypeTokenFilter(in, Collections.singleton("word"))) { - Random random = new Random(seed); - @Override - protected boolean shouldFilter() throws IOException { - return random.nextBoolean(); - } - - @Override - public void reset() throws IOException { - super.reset(); - random = new Random(seed); - } - }; + sink = new RandomSkippingFilter(sink, seed, in -> new TypeTokenFilter(in, Collections.singleton("word"))); sink = new ValidatingTokenFilter(sink, "last stage"); return new TokenStreamComponents(source, sink); } @@ -330,4 +321,64 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase { } + public void testEndWithShingles() throws IOException { + TokenStream ts = whitespaceMockTokenizer("cyk jvboq \u092e\u0962\u093f"); + ts = new GermanStemFilter(ts); + ts = new NonRandomSkippingFilter(ts, in -> new FixedShingleFilter(in, 2), true, false, true); + ts = new NonRandomSkippingFilter(ts, IndicNormalizationFilter::new, true); + + assertTokenStreamContents(ts, new String[]{"jvboq"}); + } + + private static class RandomSkippingFilter extends ConditionalTokenFilter { + + Random random; + final long seed; + + protected RandomSkippingFilter(TokenStream input, long seed, Function inputFactory) { + super(input, inputFactory); + this.seed = seed; + this.random = new Random(seed); + } + + @Override + protected boolean shouldFilter() throws IOException { + return random.nextBoolean(); + } + + @Override + public void reset() throws IOException { + super.reset(); + random = new Random(seed); + } + } + + private static class NonRandomSkippingFilter extends ConditionalTokenFilter { + + final boolean[] shouldFilters; + int pos; + + /** + * Create a new BypassingTokenFilter + * + * @param input the input TokenStream + * @param inputFactory a factory function to create a new instance of the TokenFilter to wrap + */ + protected NonRandomSkippingFilter(TokenStream input, Function inputFactory, boolean... shouldFilters) { + super(input, inputFactory); + this.shouldFilters = shouldFilters; + } + + @Override + protected boolean shouldFilter() throws IOException { + return shouldFilters[pos++ % shouldFilters.length]; + } + + @Override + public void reset() throws IOException { + super.reset(); + pos = 0; + } + } + }