LUCENE-8273: Fix end() and posInc handling

Alan Woodward 2018-05-18 13:00:20 +01:00
parent 6826c37669
commit b1ee23c525
5 changed files with 229 additions and 70 deletions

ConditionalTokenFilter.java

@@ -22,6 +22,8 @@ import java.util.function.Function;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.util.AttributeSource;
/**
@@ -35,25 +37,37 @@ import org.apache.lucene.util.AttributeSource;
public abstract class ConditionalTokenFilter extends TokenFilter {
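// READING: the outer filter is pulling tokens directly from the input
// PREBUFFERING: a token matching shouldFilter() is waiting to be handed to the delegate
// DELEGATING: tokens are being drawn from the wrapped delegate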
private enum TokenState {
READING, PREBUFFERING, BUFFERING, DELEGATING
READING, PREBUFFERING, DELEGATING
}
private final class OneTimeWrapper extends TokenStream {
private final OffsetAttribute offsetAtt;
public OneTimeWrapper(AttributeSource attributeSource) {
super(attributeSource);
this.offsetAtt = attributeSource.addAttribute(OffsetAttribute.class);
}
@Override
public boolean incrementToken() throws IOException {
if (state == TokenState.PREBUFFERING) {
state = TokenState.BUFFERING;
state = TokenState.DELEGATING;
return true;
}
if (state == TokenState.DELEGATING) {
return false;
assert state == TokenState.DELEGATING;
boolean more = input.incrementToken();
if (more && shouldFilter()) {
return true;
}
return ConditionalTokenFilter.this.incrementToken();
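// the current token should not pass through the delegate: buffer its state
// and report end-of-stream so the delegate flushes; the outer filter replays it later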
if (more) {
endOffset = offsetAtt.endOffset();
bufferedState = captureState();
}
else {
exhausted = true;
}
return false;
}
@Override
@@ -64,15 +78,28 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
@Override
public void end() throws IOException {
endCalled = true;
ConditionalTokenFilter.this.end();
// imitate Tokenizer.end() call - endAttributes, set final offset
if (exhausted) {
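// only call input.end() once, whether the wrapper or the outer filter reaches it first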
if (endCalled == false) {
input.end();
}
endCalled = true;
endOffset = offsetAtt.endOffset();
}
endAttributes();
offsetAtt.setOffset(endOffset, endOffset);
}
}
private final TokenStream delegate;
private TokenState state = TokenState.READING;
private boolean lastTokenFiltered;
private State bufferedState = null;
private boolean exhausted;
private boolean endCalled;
private int endOffset;
private PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class);
/**
* Create a new ConditionalTokenFilter
@@ -81,7 +108,7 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
*/
protected ConditionalTokenFilter(TokenStream input, Function<TokenStream, TokenStream> inputFactory) {
super(input);
this.delegate = inputFactory.apply(new OneTimeWrapper(this));
this.delegate = inputFactory.apply(new OneTimeWrapper(this.input));
}
/**
@@ -95,13 +122,20 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
this.delegate.reset();
this.state = TokenState.READING;
this.lastTokenFiltered = false;
this.bufferedState = null;
this.exhausted = false;
this.endOffset = -1;
this.endCalled = false;
}
@Override
public void end() throws IOException {
super.end();
if (endCalled == false && lastTokenFiltered) {
if (endCalled == false) {
super.end();
endCalled = true;
}
endOffset = getAttribute(OffsetAttribute.class).endOffset();
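// if the delegate produced the last token, let it do its own end-of-stream handling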
if (lastTokenFiltered) {
this.delegate.end();
}
}
@@ -116,7 +150,17 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
public final boolean incrementToken() throws IOException {
while (true) {
if (state == TokenState.READING) {
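// first, replay any token that OneTimeWrapper buffered while the delegate was draining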
if (bufferedState != null) {
restoreState(bufferedState);
bufferedState = null;
lastTokenFiltered = false;
return true;
}
if (exhausted == true) {
return false;
}
if (input.incrementToken() == false) {
exhausted = true;
return false;
}
if (shouldFilter()) {
@@ -128,17 +172,27 @@ public abstract class ConditionalTokenFilter extends TokenFilter {
// to ensure that it can continue to emit more tokens
delegate.reset();
boolean more = delegate.incrementToken();
state = TokenState.DELEGATING;
return more;
if (more) {
state = TokenState.DELEGATING;
}
else {
lastTokenFiltered = false;
state = TokenState.READING;
if (bufferedState != null) {
delegate.end();
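// the delegate's end() may set a trailing position increment; add it to
// the buffered token's own increment before replaying it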
int posInc = posIncAtt.getPositionIncrement();
restoreState(bufferedState);
posIncAtt.setPositionIncrement(posIncAtt.getPositionIncrement() + posInc);
bufferedState = null;
return true;
}
}
return more || bufferedState != null;
}
lastTokenFiltered = false;
return true;
}
if (state == TokenState.BUFFERING) {
return input.incrementToken();
}
if (state == TokenState.DELEGATING) {
clearAttributes();
if (delegate.incrementToken()) {
return true;
}

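For orientation, a minimal usage sketch of the class this commit modifies (illustrative only, not part of the commit; the example class name is hypothetical and the imports assume the Lucene 7.x analysis packages). Tokens for which shouldFilter() returns true are routed through the wrapped filter; all other tokens bypass it unchanged:

import java.io.IOException;
import java.io.StringReader;
import org.apache.lucene.analysis.LowerCaseFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.miscellaneous.ConditionalTokenFilter;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

public class ConditionalTokenFilterExample {
  public static void main(String[] args) throws IOException {
    WhitespaceTokenizer source = new WhitespaceTokenizer();
    source.setReader(new StringReader("Alice Bob Clara David"));
    // lowercase only the tokens containing an 'o'; the rest bypass the filter
    TokenStream ts = new ConditionalTokenFilter(source, LowerCaseFilter::new) {
      private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
      @Override
      protected boolean shouldFilter() {
        return termAtt.toString().contains("o");
      }
    };
    CharTermAttribute term = ts.getAttribute(CharTermAttribute.class);
    ts.reset();
    while (ts.incrementToken()) {
      System.out.println(term); // Alice, bob, Clara, David
    }
    ts.end();
    ts.close();
  }
}

The fixes in this file ensure that a token buffered while the delegate drains (for example, a ShingleFilter that is still emitting) is replayed with its position increment and final offset corrected.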
TestRandomChains.java

@@ -73,6 +73,7 @@ import org.apache.lucene.analysis.hunspell.Dictionary;
import org.apache.lucene.analysis.hunspell.TestHunspellStemFilter;
import org.apache.lucene.analysis.miscellaneous.ConditionalTokenFilter;
import org.apache.lucene.analysis.miscellaneous.DelimitedTermFrequencyTokenFilter;
import org.apache.lucene.analysis.miscellaneous.FingerprintFilter;
import org.apache.lucene.analysis.miscellaneous.HyphenatedWordsFilter;
import org.apache.lucene.analysis.miscellaneous.LimitTokenCountFilter;
import org.apache.lucene.analysis.miscellaneous.LimitTokenOffsetFilter;
@@ -93,7 +94,6 @@ import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.AttributeFactory;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.Rethrow;
import org.apache.lucene.util.TestUtil;
import org.apache.lucene.util.Version;
@@ -108,7 +108,6 @@ import org.tartarus.snowball.SnowballProgram;
import org.xml.sax.InputSource;
/** tests random analysis chains */
@LuceneTestCase.AwaitsFix(bugUrl="https://issues.apache.org/jira/browse/LUCENE-8273")
public class TestRandomChains extends BaseTokenStreamTestCase {
static List<Constructor<? extends Tokenizer>> tokenizers;
@@ -117,6 +116,12 @@ public class TestRandomChains extends BaseTokenStreamTestCase {
private static final Predicate<Object[]> ALWAYS = (objects -> true);
private static final Set<Class<?>> avoidConditionals = new HashSet<>();
static {
// Fingerprint filter needs to consume the whole tokenstream, so conditionals don't make sense here
avoidConditionals.add(FingerprintFilter.class);
}
private static final Map<Constructor<?>,Predicate<Object[]>> brokenConstructors = new HashMap<>();
static {
try {
@@ -703,7 +708,7 @@
while (true) {
final Constructor<? extends TokenFilter> ctor = tokenfilters.get(random.nextInt(tokenfilters.size()));
if (random.nextBoolean()) {
if (random.nextBoolean() && avoidConditionals.contains(ctor.getDeclaringClass()) == false) {
long seed = random.nextLong();
spec.stream = new ConditionalTokenFilter(spec.stream, in -> {
final Object args[] = newFilterArgs(random, in, ctor.getParameterTypes());

TestConditionalTokenFilter.java

@@ -19,16 +19,27 @@ package org.apache.lucene.analysis.miscellaneous;
import java.io.IOException;
import java.io.StringReader;
import java.util.Collections;
import java.util.Random;
import java.util.function.Function;
import java.util.regex.Pattern;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.analysis.CannedTokenStream;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.CharacterUtils;
import org.apache.lucene.analysis.FilteringTokenFilter;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.ValidatingTokenFilter;
import org.apache.lucene.analysis.core.TypeTokenFilter;
import org.apache.lucene.analysis.ngram.NGramTokenizer;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.apache.lucene.analysis.standard.ClassicTokenizer;
import org.apache.lucene.analysis.synonym.SolrSynonymParser;
import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
import org.apache.lucene.analysis.synonym.SynonymMap;
@@ -77,23 +88,23 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
}
}
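// routes tokens that do NOT match the regex through the wrapped filter;
// matching tokens bypass it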
private class SkipMatchingFilter extends ConditionalTokenFilter {
private final Pattern pattern;
private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
SkipMatchingFilter(TokenStream input, Function<TokenStream, TokenStream> inputFactory, String termRegex) {
super(input, inputFactory);
pattern = Pattern.compile(termRegex);
}
@Override
protected boolean shouldFilter() throws IOException {
return pattern.matcher(termAtt.toString()).matches() == false;
}
}
public void testSimple() throws IOException {
CannedTokenStream cts = new CannedTokenStream(
new Token("Alice", 1, 0, 5),
new Token("Bob", 1, 6, 9),
new Token("Clara", 1, 10, 15),
new Token("David", 1, 16, 21)
);
TokenStream t = new ConditionalTokenFilter(cts, AssertingLowerCaseFilter::new) {
CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
@Override
protected boolean shouldFilter() throws IOException {
return termAtt.toString().contains("o") == false;
}
};
TokenStream stream = whitespaceMockTokenizer("Alice Bob Clara David");
TokenStream t = new SkipMatchingFilter(stream, AssertingLowerCaseFilter::new, ".*o.*");
assertTokenStreamContents(t, new String[]{ "alice", "Bob", "clara", "david" });
assertTrue(closed);
assertTrue(reset);
@@ -103,6 +114,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
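// test filter that splits each token into two pieces (the first four chars,
// then the remainder), so the delegate emits more tokens than it consumes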
private final class TokenSplitter extends TokenFilter {
final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
State state = null;
String half;
protected TokenSplitter(TokenStream input) {
@@ -112,6 +124,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
@Override
public boolean incrementToken() throws IOException {
if (half == null) {
state = captureState();
if (input.incrementToken() == false) {
return false;
}
@@ -119,6 +132,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
termAtt.setLength(4);
return true;
}
restoreState(state);
termAtt.setEmpty().append(half);
half = null;
return true;
@@ -126,21 +140,8 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
}
public void testMultitokenWrapping() throws IOException {
CannedTokenStream cts = new CannedTokenStream(
new Token("tokenpos1", 0, 9),
new Token("tokenpos2", 10, 19),
new Token("tokenpos3", 20, 29),
new Token("tokenpos4", 30, 39)
);
TokenStream ts = new ConditionalTokenFilter(cts, TokenSplitter::new) {
final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
@Override
protected boolean shouldFilter() throws IOException {
return termAtt.toString().contains("2") == false;
}
};
TokenStream stream = whitespaceMockTokenizer("tokenpos1 tokenpos2 tokenpos3 tokenpos4");
TokenStream ts = new SkipMatchingFilter(stream, TokenSplitter::new, ".*2.*");
assertTokenStreamContents(ts, new String[]{
"toke", "npos1", "tokenpos2", "toke", "npos3", "toke", "npos4"
});
@@ -194,13 +195,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
public void testWrapGraphs() throws Exception {
CannedTokenStream cts = new CannedTokenStream(
new Token("a", 0, 1),
new Token("b", 2, 3),
new Token("c", 4, 5),
new Token("d", 6, 7),
new Token("e", 8, 9)
);
TokenStream stream = whitespaceMockTokenizer("a b c d e");
SynonymMap sm;
try (Analyzer analyzer = new MockAnalyzer(random())) {
@@ -209,13 +204,7 @@ public class TestConditionalTokenFilter extends BaseTokenStreamTestCase {
sm = parser.build();
}
TokenStream ts = new ConditionalTokenFilter(cts, in -> new SynonymGraphFilter(in, sm, true)) {
CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
@Override
protected boolean shouldFilter() throws IOException {
return "c".equals(termAtt.toString()) == false;
}
};
TokenStream ts = new SkipMatchingFilter(stream, in -> new SynonymGraphFilter(in, sm, true), "c");
assertTokenStreamContents(ts, new String[]{
"f", "a", "b", "c", "d", "e"
@@ -230,4 +219,115 @@
}
public void testReadaheadWithNoFiltering() throws IOException {
Analyzer analyzer = new Analyzer() {
@Override
protected TokenStreamComponents createComponents(String fieldName) {
Tokenizer source = new ClassicTokenizer();
TokenStream sink = new ConditionalTokenFilter(source, in -> new ShingleFilter(in, 2)) {
@Override
protected boolean shouldFilter() throws IOException {
return true;
}
};
return new TokenStreamComponents(source, sink);
}
};
String input = "one two three four";
try (TokenStream ts = analyzer.tokenStream("", input)) {
assertTokenStreamContents(ts, new String[]{
"one", "one two",
"two", "two three",
"three", "three four",
"four"
});
}
}
public void testReadaheadWithFiltering() throws IOException {
CharArraySet exclusions = new CharArraySet(2, true);
exclusions.add("three");
Analyzer analyzer = new Analyzer() {
@Override
protected TokenStreamComponents createComponents(String fieldName) {
Tokenizer source = new ClassicTokenizer();
TokenStream sink = new TermExclusionFilter(exclusions, source, in -> new ShingleFilter(in, 2));
return new TokenStreamComponents(source, sink);
}
};
String input = "one two three four";
try (TokenStream ts = analyzer.tokenStream("", input)) {
assertTokenStreamContents(ts, new String[]{
"one", "one two",
"two",
"three",
"four"
});
}
}
public void testMultipleConditionalFilters() throws IOException {
TokenStream stream = whitespaceMockTokenizer("Alice Bob Clara David");
TokenStream t = new SkipMatchingFilter(stream, in -> {
TruncateTokenFilter truncateFilter = new TruncateTokenFilter(in, 2);
return new AssertingLowerCaseFilter(truncateFilter);
}, ".*o.*");
assertTokenStreamContents(t, new String[]{"al", "Bob", "cl", "da"});
assertTrue(closed);
assertTrue(reset);
assertTrue(ended);
}
public void testFilteredTokenFilters() throws IOException {
CharArraySet exclusions = new CharArraySet(2, true);
exclusions.add("foobar");
TokenStream ts = whitespaceMockTokenizer("wuthering foobar abc");
ts = new TermExclusionFilter(exclusions, ts, in -> new LengthFilter(in, 1, 4));
assertTokenStreamContents(ts, new String[]{ "foobar", "abc" });
ts = whitespaceMockTokenizer("foobar abc");
ts = new TermExclusionFilter(exclusions, ts, in -> new LengthFilter(in, 1, 4));
assertTokenStreamContents(ts, new String[]{ "foobar", "abc" });
}
public void testConsistentOffsets() throws IOException {
long seed = random().nextLong();
Analyzer analyzer = new Analyzer() {
@Override
protected TokenStreamComponents createComponents(String fieldName) {
Tokenizer source = new NGramTokenizer();
TokenStream sink = new KeywordRepeatFilter(source);
sink = new ConditionalTokenFilter(sink, in -> new TypeTokenFilter(in, Collections.singleton("word"))) {
Random random = new Random(seed);
@Override
protected boolean shouldFilter() throws IOException {
return random.nextBoolean();
}
@Override
public void reset() throws IOException {
super.reset();
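// re-seed so the filter/no-filter decisions repeat exactly on each pass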
random = new Random(seed);
}
};
sink = new ValidatingTokenFilter(sink, "last stage");
return new TokenStreamComponents(source, sink);
}
};
checkRandomData(random(), analyzer, 1);
}
}

BaseTokenStreamTestCase.java

@@ -194,7 +194,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
assertTrue("token "+i+" does not exist", ts.incrementToken());
assertTrue("clearAttributes() was not called correctly in TokenStream chain", checkClearAtt.getAndResetClearCalled());
assertTrue("clearAttributes() was not called correctly in TokenStream chain at token " + i, checkClearAtt.getAndResetClearCalled());
assertEquals("term "+i, output[i], termAtt.toString());
if (startOffsets != null) {
@@ -438,7 +438,7 @@
} finally {
// consume correctly
ts.reset();
while (ts.incrementToken()) {}
while (ts.incrementToken()) { }
ts.end();
ts.close();
}

ValidatingTokenFilter.java

@@ -96,11 +96,11 @@ public final class ValidatingTokenFilter extends TokenFilter {
if (!posToStartOffset.containsKey(pos)) {
// First time we've seen a token leaving from this position:
posToStartOffset.put(pos, startOffset);
//System.out.println(" + s " + pos + " -> " + startOffset);
// System.out.println(name + " + s " + pos + " -> " + startOffset);
} else {
// We've seen a token leaving from this position
// before; verify the startOffset is the same:
//System.out.println(" + vs " + pos + " -> " + startOffset);
// System.out.println(name + " + vs " + pos + " -> " + startOffset);
final int oldStartOffset = posToStartOffset.get(pos);
if (oldStartOffset != startOffset) {
throw new IllegalStateException(name + ": inconsistent startOffset at pos=" + pos + ": " + oldStartOffset + " vs " + startOffset + "; token=" + termAtt);
@@ -112,11 +112,11 @@ public final class ValidatingTokenFilter extends TokenFilter {
if (!posToEndOffset.containsKey(endPos)) {
// First time we've seen a token arriving to this position:
posToEndOffset.put(endPos, endOffset);
//System.out.println(" + e " + endPos + " -> " + endOffset);
//System.out.println(name + " + e " + endPos + " -> " + endOffset);
} else {
// We've seen a token arriving to this position
// before; verify the endOffset is the same:
//System.out.println(" + ve " + endPos + " -> " + endOffset);
//System.out.println(name + " + ve " + endPos + " -> " + endOffset);
final int oldEndOffset = posToEndOffset.get(endPos);
if (oldEndOffset != endOffset) {
throw new IllegalStateException(name + ": inconsistent endOffset at pos=" + endPos + ": " + oldEndOffset + " vs " + endOffset + "; token=" + termAtt);