diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java
index e11530d7923..c8b91dc4b04 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/miscellaneous/ConditionalTokenFilter.java
@@ -22,6 +22,8 @@ import java.util.function.Function;
 
 import org.apache.lucene.analysis.TokenFilter;
 import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
 import org.apache.lucene.util.AttributeSource;
 
 /**
@@ -35,25 +37,37 @@ import org.apache.lucene.util.AttributeSource;
 public abstract class ConditionalTokenFilter extends TokenFilter {
 
   private enum TokenState {
-    READING, PREBUFFERING, BUFFERING, DELEGATING
+    READING, PREBUFFERING, DELEGATING
   }
 
   private final class OneTimeWrapper extends TokenStream {
 
+    private final OffsetAttribute offsetAtt;
+
     public OneTimeWrapper(AttributeSource attributeSource) {
       super(attributeSource);
+      this.offsetAtt = attributeSource.addAttribute(OffsetAttribute.class);
     }
 
     @Override
     public boolean incrementToken() throws IOException {
       if (state == TokenState.PREBUFFERING) {
-        state = TokenState.BUFFERING;
+        state = TokenState.DELEGATING;
         return true;
       }
-      if (state == TokenState.DELEGATING) {
-        return false;
+      assert state == TokenState.DELEGATING;
+      boolean more = input.incrementToken();
+      if (more && shouldFilter()) {
+        return true;
       }
-      return ConditionalTokenFilter.this.incrementToken();
+      if (more) {
+        endOffset = offsetAtt.endOffset();
+        bufferedState = captureState();
+      }
+      else {
+        exhausted = true;
+      }
+      return false;
     }
 
     @Override
@@ -64,15 +78,28 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
 
     @Override
     public void end() throws IOException {
-      endCalled = true;
-      ConditionalTokenFilter.this.end();
+      // imitate Tokenizer.end() call - endAttributes, set final offset
+      if (exhausted) {
+        if (endCalled == false) {
+          input.end();
+        }
+        endCalled = true;
+        endOffset = offsetAtt.endOffset();
+      }
+      endAttributes();
+      offsetAtt.setOffset(endOffset, endOffset);
     }
   }
 
   private final TokenStream delegate;
   private TokenState state = TokenState.READING;
   private boolean lastTokenFiltered;
+  private State bufferedState = null;
+  private boolean exhausted;
   private boolean endCalled;
+  private int endOffset;
+
+  private PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class);
 
   /**
    * Create a new BypassingTokenFilter
@@ -81,7 +108,7 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
    */
  protected ConditionalTokenFilter(TokenStream input, Function<TokenStream, TokenStream> inputFactory) {
     super(input);
-    this.delegate = inputFactory.apply(new OneTimeWrapper(this));
+    this.delegate = inputFactory.apply(new OneTimeWrapper(this.input));
   }
 
   /**
@@ -95,13 +122,20 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
     this.delegate.reset();
     this.state = TokenState.READING;
     this.lastTokenFiltered = false;
+    this.bufferedState = null;
+    this.exhausted = false;
+    this.endOffset = -1;
     this.endCalled = false;
   }
 
   @Override
   public void end() throws IOException {
-    super.end();
-    if (endCalled == false && lastTokenFiltered) {
+    if (endCalled == false) {
+      super.end();
+      endCalled = true;
+    }
+    endOffset = getAttribute(OffsetAttribute.class).endOffset();
+    if (lastTokenFiltered) {
       this.delegate.end();
     }
   }
@@ -116,7 +150,17 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
   public final boolean incrementToken() throws IOException {
     while (true) {
       if (state == TokenState.READING) {
+        if (bufferedState != null) {
+          restoreState(bufferedState);
+          bufferedState = null;
+          lastTokenFiltered = false;
+          return true;
+        }
+        if (exhausted == true) {
+          return false;
+        }
         if (input.incrementToken() == false) {
+          exhausted = true;
           return false;
         }
         if (shouldFilter()) {
@@ -128,17 +172,27 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
           // to ensure that it can continue to emit more tokens
           delegate.reset();
           boolean more = delegate.incrementToken();
-          state = TokenState.DELEGATING;
-          return more;
+          if (more) {
+            state = TokenState.DELEGATING;
+          }
+          else {
+            lastTokenFiltered = false;
+            state = TokenState.READING;
+            if (bufferedState != null) {
+              delegate.end();
+              int posInc = posIncAtt.getPositionIncrement();
+              restoreState(bufferedState);
+              posIncAtt.setPositionIncrement(posIncAtt.getPositionIncrement() + posInc);
+              bufferedState = null;
+              return true;
+            }
+          }
+          return more || bufferedState != null;
         }
         lastTokenFiltered = false;
         return true;
       }
-      if (state == TokenState.BUFFERING) {
-        return input.incrementToken();
-      }
       if (state == TokenState.DELEGATING) {
-        clearAttributes();
         if (delegate.incrementToken()) {
           return true;
         }
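The change above removes the old BUFFERING state in favour of a one-token
lookahead: OneTimeWrapper now feeds the wrapped filter directly from the
underlying stream, captures the first token that fails shouldFilter() into
bufferedState, and reports end-of-stream to the delegate. incrementToken()
later replays the buffered token with its position increment folded in, and
OneTimeWrapper.end() imitates Tokenizer.end() so final offsets stay
consistent. For orientation, a subclass only implements shouldFilter(); the
sketch below is illustrative only (the class name and the length-based
condition are invented here, not taken from the patch):

    import java.io.IOException;
    import org.apache.lucene.analysis.TokenStream;
    import org.apache.lucene.analysis.core.LowerCaseFilter;
    import org.apache.lucene.analysis.miscellaneous.ConditionalTokenFilter;
    import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

    // Hypothetical example: lowercase only terms of four characters or fewer;
    // longer terms bypass the wrapped LowerCaseFilter untouched.
    final class LowerCaseShortTermsFilter extends ConditionalTokenFilter {
      private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);

      LowerCaseShortTermsFilter(TokenStream input) {
        super(input, LowerCaseFilter::new); // factory builds the conditional branch
      }

      @Override
      protected boolean shouldFilter() throws IOException {
        return termAtt.length() <= 4; // true => route this token through the branch
      }
    }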
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java
index 83eb739bfcf..e393e5e97de 100644
--- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestRandomChains.java
@@ -73,6 +73,7 @@ import org.apache.lucene.analysis.hunspell.Dictionary;
 import org.apache.lucene.analysis.hunspell.TestHunspellStemFilter;
 import org.apache.lucene.analysis.miscellaneous.ConditionalTokenFilter;
 import org.apache.lucene.analysis.miscellaneous.DelimitedTermFrequencyTokenFilter;
+import org.apache.lucene.analysis.miscellaneous.FingerprintFilter;
 import org.apache.lucene.analysis.miscellaneous.HyphenatedWordsFilter;
 import org.apache.lucene.analysis.miscellaneous.LimitTokenCountFilter;
 import org.apache.lucene.analysis.miscellaneous.LimitTokenOffsetFilter;
@@ -93,7 +94,6 @@ import org.apache.lucene.store.RAMDirectory;
 import org.apache.lucene.util.AttributeFactory;
 import org.apache.lucene.util.AttributeSource;
 import org.apache.lucene.util.CharsRef;
-import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.Rethrow;
 import org.apache.lucene.util.TestUtil;
 import org.apache.lucene.util.Version;
@@ -108,7 +108,6 @@ import org.tartarus.snowball.SnowballProgram;
 import org.xml.sax.InputSource;
 
 /** tests random analysis chains */
-@LuceneTestCase.AwaitsFix(bugUrl="https://issues.apache.org/jira/browse/LUCENE-8273")
 public class TestRandomChains extends BaseTokenStreamTestCase {
 
   static List<Constructor<? extends Tokenizer>> tokenizers;
@@ -117,6 +116,12 @@ public class TestRandomChains extends BaseTokenStreamTestCase {
 
   private static final Predicate<Object[]> ALWAYS = (objects -> true);
 
+  private static final Set<Class<?>> avoidConditionals = new HashSet<>();
+  static {
+    // Fingerprint filter needs to consume the whole tokenstream, so conditionals don't make sense here
+    avoidConditionals.add(FingerprintFilter.class);
+  }
+
   private static final Map<Constructor<?>,Predicate<Object[]>> brokenConstructors = new HashMap<>();
   static {
     try {
@@ -703,7 +708,7 @@ public class TestRandomChains extends BaseTokenStreamTestCase {
       while (true) {
         final Constructor<? extends TokenFilter> ctor = tokenfilters.get(random.nextInt(tokenfilters.size()));
 
-        if (random.nextBoolean()) {
+        if (random.nextBoolean() && avoidConditionals.contains(ctor.getDeclaringClass()) == false) {
           long seed = random.nextLong();
           spec.stream = new ConditionalTokenFilter(spec.stream, in -> {
             final Object args[] = newFilterArgs(random, in, ctor.getParameterTypes());
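With the AwaitsFix annotation removed, TestRandomChains again wraps roughly
half of the randomly chosen filters in a ConditionalTokenFilter, now skipping
filters listed in avoidConditionals. The hunk is cut off above after the
newFilterArgs call, but the seeded-Random pattern it relies on is the one
shown in testConsistentOffsets below: reseeding in reset() makes the random
decisions repeatable across multiple passes over the stream. A standalone
sketch of that pattern, with an invented class name:

    import java.io.IOException;
    import java.util.Random;
    import java.util.function.Function;
    import org.apache.lucene.analysis.TokenStream;
    import org.apache.lucene.analysis.miscellaneous.ConditionalTokenFilter;

    // Randomly bypasses the wrapped filter, but deterministically: the same
    // seed yields the same per-token decisions on every pass.
    final class RandomConditionalFilter extends ConditionalTokenFilter {
      private final long seed;
      private Random random;

      RandomConditionalFilter(TokenStream input,
                              Function<TokenStream, TokenStream> inner, long seed) {
        super(input, inner);
        this.seed = seed;
        this.random = new Random(seed);
      }

      @Override
      protected boolean shouldFilter() throws IOException {
        return random.nextBoolean();
      }

      @Override
      public void reset() throws IOException {
        super.reset();
        random = new Random(seed); // replay identical decisions after reset
      }
    }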
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java
index e804676fb7c..02d8a780be5 100644
--- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/miscellaneous/TestConditionalTokenFilter.java
@@ -19,16 +19,27 @@ package org.apache.lucene.analysis.miscellaneous;
 
 import java.io.IOException;
 import java.io.StringReader;
+import java.util.Collections;
+import java.util.Random;
+import java.util.function.Function;
+import java.util.regex.Pattern;
 
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.BaseTokenStreamTestCase;
 import org.apache.lucene.analysis.CannedTokenStream;
+import org.apache.lucene.analysis.CharArraySet;
 import org.apache.lucene.analysis.CharacterUtils;
 import org.apache.lucene.analysis.FilteringTokenFilter;
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.analysis.Token;
 import org.apache.lucene.analysis.TokenFilter;
 import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.analysis.ValidatingTokenFilter;
+import org.apache.lucene.analysis.core.TypeTokenFilter;
+import org.apache.lucene.analysis.ngram.NGramTokenizer;
+import org.apache.lucene.analysis.shingle.ShingleFilter;
+import org.apache.lucene.analysis.standard.ClassicTokenizer;
 import org.apache.lucene.analysis.synonym.SolrSynonymParser;
 import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
 import org.apache.lucene.analysis.synonym.SynonymMap;
@@ -77,23 +88,23 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
     }
   }
 
+  private class SkipMatchingFilter extends ConditionalTokenFilter {
+    private final Pattern pattern;
+    private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
+
+    SkipMatchingFilter(TokenStream input, Function<TokenStream, TokenStream> inputFactory, String termRegex) {
+      super(input, inputFactory);
+      pattern = Pattern.compile(termRegex);
+    }
+
+    @Override
+    protected boolean shouldFilter() throws IOException {
+      return pattern.matcher(termAtt.toString()).matches() == false;
+    }
+  }
+
   public void testSimple() throws IOException {
-
-    CannedTokenStream cts = new CannedTokenStream(
-        new Token("Alice", 1, 0, 5),
-        new Token("Bob", 1, 6, 9),
-        new Token("Clara", 1, 10, 15),
-        new Token("David", 1, 16, 21)
-    );
-
-    TokenStream t = new ConditionalTokenFilter(cts, AssertingLowerCaseFilter::new) {
-      CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
-      @Override
-      protected boolean shouldFilter() throws IOException {
-        return termAtt.toString().contains("o") == false;
-      }
-    };
-
+    TokenStream stream = whitespaceMockTokenizer("Alice Bob Clara David");
+    TokenStream t = new SkipMatchingFilter(stream, AssertingLowerCaseFilter::new, ".*o.*");
     assertTokenStreamContents(t, new String[]{ "alice", "Bob", "clara", "david" });
     assertTrue(closed);
     assertTrue(reset);
@@ -103,6 +114,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
   private final class TokenSplitter extends TokenFilter {
 
     final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
+    State state = null;
     String half;
 
     protected TokenSplitter(TokenStream input) {
@@ -112,6 +124,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
     @Override
     public boolean incrementToken() throws IOException {
       if (half == null) {
+        state = captureState();
         if (input.incrementToken() == false) {
           return false;
         }
@@ -119,6 +132,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
         termAtt.setLength(4);
         return true;
       }
+      restoreState(state);
       termAtt.setEmpty().append(half);
       half = null;
       return true;
@@ -126,21 +140,8 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
   }
 
   public void testMultitokenWrapping() throws IOException {
-    CannedTokenStream cts = new CannedTokenStream(
-        new Token("tokenpos1", 0, 9),
-        new Token("tokenpos2", 10, 19),
-        new Token("tokenpos3", 20, 29),
-        new Token("tokenpos4", 30, 39)
-    );
-
-    TokenStream ts = new ConditionalTokenFilter(cts, TokenSplitter::new) {
-      final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
-      @Override
-      protected boolean shouldFilter() throws IOException {
-        return termAtt.toString().contains("2") == false;
-      }
-    };
-
+    TokenStream stream = whitespaceMockTokenizer("tokenpos1 tokenpos2 tokenpos3 tokenpos4");
+    TokenStream ts = new SkipMatchingFilter(stream, TokenSplitter::new, ".*2.*");
     assertTokenStreamContents(ts, new String[]{
         "toke", "npos1", "tokenpos2", "toke", "npos3", "toke", "npos4"
     });
@@ -194,13 +195,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
 
   public void testWrapGraphs() throws Exception {
 
-    CannedTokenStream cts = new CannedTokenStream(
-        new Token("a", 0, 1),
-        new Token("b", 2, 3),
-        new Token("c", 4, 5),
-        new Token("d", 6, 7),
-        new Token("e", 8, 9)
-    );
+    TokenStream stream = whitespaceMockTokenizer("a b c d e");
 
     SynonymMap sm;
     try (Analyzer analyzer = new MockAnalyzer(random())) {
@@ -209,13 +204,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
       sm = parser.build();
     }
 
-    TokenStream ts = new ConditionalTokenFilter(cts, in -> new SynonymGraphFilter(in, sm, true)) {
-      CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
-      @Override
-      protected boolean shouldFilter() throws IOException {
-        return "c".equals(termAtt.toString()) == false;
-      }
-    };
+    TokenStream ts = new SkipMatchingFilter(stream, in -> new SynonymGraphFilter(in, sm, true), "c");
 
     assertTokenStreamContents(ts, new String[]{
         "f", "a", "b", "c", "d", "e"
@@ -230,4 +219,115 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
 
   }
 
+  public void testReadaheadWithNoFiltering() throws IOException {
+    Analyzer analyzer = new Analyzer() {
+      @Override
+      protected TokenStreamComponents createComponents(String fieldName) {
+        Tokenizer source = new ClassicTokenizer();
+        TokenStream sink = new ConditionalTokenFilter(source, in -> new ShingleFilter(in, 2)) {
+          @Override
+          protected boolean shouldFilter() throws IOException {
+            return true;
+          }
+        };
+        return new TokenStreamComponents(source, sink);
+      }
+    };
+
+    String input = "one two three four";
+
+    try (TokenStream ts = analyzer.tokenStream("", input)) {
+      assertTokenStreamContents(ts, new String[]{
"one", "one two", + "two", "two three", + "three", "three four", + "four" + }); + } + } + + public void testReadaheadWithFiltering() throws IOException { + + CharArraySet exclusions = new CharArraySet(2, true); + exclusions.add("three"); + + Analyzer analyzer = new Analyzer() { + @Override + protected TokenStreamComponents createComponents(String fieldName) { + Tokenizer source = new ClassicTokenizer(); + TokenStream sink = new TermExclusionFilter(exclusions, source, in -> new ShingleFilter(in, 2)); + return new TokenStreamComponents(source, sink); + } + }; + + String input = "one two three four"; + + try (TokenStream ts = analyzer.tokenStream("", input)) { + assertTokenStreamContents(ts, new String[]{ + "one", "one two", + "two", + "three", + "four" + }); + } + } + + public void testMultipleConditionalFilters() throws IOException { + TokenStream stream = whitespaceMockTokenizer("Alice Bob Clara David"); + TokenStream t = new SkipMatchingFilter(stream, in -> { + TruncateTokenFilter truncateFilter = new TruncateTokenFilter(in, 2); + return new AssertingLowerCaseFilter(truncateFilter); + }, ".*o.*"); + + assertTokenStreamContents(t, new String[]{"al", "Bob", "cl", "da"}); + assertTrue(closed); + assertTrue(reset); + assertTrue(ended); + } + + public void testFilteredTokenFilters() throws IOException { + + CharArraySet exclusions = new CharArraySet(2, true); + exclusions.add("foobar"); + + TokenStream ts = whitespaceMockTokenizer("wuthering foobar abc"); + ts = new TermExclusionFilter(exclusions, ts, in -> new LengthFilter(in, 1, 4)); + assertTokenStreamContents(ts, new String[]{ "foobar", "abc" }); + + ts = whitespaceMockTokenizer("foobar abc"); + ts = new TermExclusionFilter(exclusions, ts, in -> new LengthFilter(in, 1, 4)); + assertTokenStreamContents(ts, new String[]{ "foobar", "abc" }); + + } + + public void testConsistentOffsets() throws IOException { + + long seed = random().nextLong(); + Analyzer analyzer = new Analyzer() { + @Override + protected TokenStreamComponents createComponents(String fieldName) { + Tokenizer source = new NGramTokenizer(); + TokenStream sink = new KeywordRepeatFilter(source); + sink = new ConditionalTokenFilter(sink, in -> new TypeTokenFilter(in, Collections.singleton("word"))) { + Random random = new Random(seed); + @Override + protected boolean shouldFilter() throws IOException { + return random.nextBoolean(); + } + + @Override + public void reset() throws IOException { + super.reset(); + random = new Random(seed); + } + }; + sink = new ValidatingTokenFilter(sink, "last stage"); + return new TokenStreamComponents(source, sink); + } + }; + + checkRandomData(random(), analyzer, 1); + + } + } diff --git a/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java index e4897c59f00..c936a359523 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java @@ -194,7 +194,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase { checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before assertTrue("token "+i+" does not exist", ts.incrementToken()); - assertTrue("clearAttributes() was not called correctly in TokenStream chain", checkClearAtt.getAndResetClearCalled()); + assertTrue("clearAttributes() was not called correctly in TokenStream chain at token " + i, 
diff --git a/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java
index e4897c59f00..c936a359523 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/analysis/BaseTokenStreamTestCase.java
@@ -194,7 +194,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
 
       checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
       assertTrue("token "+i+" does not exist", ts.incrementToken());
-      assertTrue("clearAttributes() was not called correctly in TokenStream chain", checkClearAtt.getAndResetClearCalled());
+      assertTrue("clearAttributes() was not called correctly in TokenStream chain at token " + i, checkClearAtt.getAndResetClearCalled());
 
       assertEquals("term "+i, output[i], termAtt.toString());
       if (startOffsets != null) {
@@ -438,7 +438,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
     } finally {
       // consume correctly
       ts.reset();
-      while (ts.incrementToken()) {}
+      while (ts.incrementToken()) { }
       ts.end();
       ts.close();
     }
diff --git a/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java b/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java
index 7901eea5bf2..603ef200825 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/analysis/ValidatingTokenFilter.java
@@ -96,11 +96,11 @@ public final class ValidatingTokenFilter extends TokenFilter {
       if (!posToStartOffset.containsKey(pos)) {
         // First time we've seen a token leaving from this position:
         posToStartOffset.put(pos, startOffset);
-        //System.out.println("  + s " + pos + " -> " + startOffset);
+        // System.out.println(name + " + s " + pos + " -> " + startOffset);
       } else {
         // We've seen a token leaving from this position
         // before; verify the startOffset is the same:
-        //System.out.println("  + vs " + pos + " -> " + startOffset);
+        // System.out.println(name + " + vs " + pos + " -> " + startOffset);
         final int oldStartOffset = posToStartOffset.get(pos);
         if (oldStartOffset != startOffset) {
           throw new IllegalStateException(name + ": inconsistent startOffset at pos=" + pos + ": " + oldStartOffset + " vs " + startOffset + "; token=" + termAtt);
@@ -112,11 +112,11 @@ public final class ValidatingTokenFilter extends TokenFilter {
       if (!posToEndOffset.containsKey(endPos)) {
         // First time we've seen a token arriving to this position:
         posToEndOffset.put(endPos, endOffset);
-        //System.out.println("  + e " + endPos + " -> " + endOffset);
+        //System.out.println(name + " + e " + endPos + " -> " + endOffset);
       } else {
         // We've seen a token arriving to this position
         // before; verify the endOffset is the same:
-        //System.out.println("  + ve " + endPos + " -> " + endOffset);
+        //System.out.println(name + " + ve " + endPos + " -> " + endOffset);
         final int oldEndOffset = posToEndOffset.get(endPos);
         if (oldEndOffset != endOffset) {
           throw new IllegalStateException(name + ": inconsistent endOffset at pos=" + endPos + ": " + oldEndOffset + " vs " + endOffset + "; token=" + termAtt);
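The "consume correctly" block patched above follows the standard TokenStream
contract that all of these tests rely on: reset(), then incrementToken()
until it returns false, then end(), then close(). For reference, the full
cycle (the analyzer, field name, and text are placeholders):

    // Standard consumption cycle. With this patch, ConditionalTokenFilter.end()
    // records the final offset and propagates end() to its delegate when the
    // last token was filtered.
    try (TokenStream ts = analyzer.tokenStream("field", "some text")) {
      CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
      ts.reset();                      // required before the first incrementToken()
      while (ts.incrementToken()) {
        System.out.println(termAtt);   // per-token attributes are valid here
      }
      ts.end();                        // final offset / position increment state
    }                                  // try-with-resources calls close()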