LUCENE-8202: Fix positionlength for FixedShingleFilter and add limits to shingle size and count

This commit is contained in:
Alan Woodward 2018-03-25 13:22:34 +01:00
parent 273a829c46
commit bbf8306615
2 changed files with 52 additions and 16 deletions

View File

@ -26,23 +26,22 @@ import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
import org.apache.lucene.util.AttributeSource;
/**
* A FixedShingleFilter constructs shingles (token n-grams) from a token stream.
* In other words, it creates combinations of tokens as a single token.
*
* <p>
* Unlike the {@link ShingleFilter}, FixedShingleFilter only emits shingles of a
* fixed size, and never emits unigrams, even at the end of a TokenStream. In
* addition, if the filter encounters stacked tokens (eg synonyms), then it will
* output stacked shingles
*
* <p>
* For example, the sentence "please divide this sentence into shingles"
* might be tokenized into shingles "please divide", "divide this",
* "this sentence", "sentence into", and "into shingles".
*
* <p>
* This filter handles position increments &gt; 1 by inserting filler tokens
* (tokens with termtext "_").
*
@ -52,22 +51,27 @@ public final class FixedShingleFilter extends TokenFilter {
private final Deque<Token> tokenPool = new ArrayDeque<>();
private static final int MAX_SHINGLE_STACK_SIZE = 1000;
private static final int MAX_SHINGLE_SIZE = 4;
private final int shingleSize;
private final String tokenSeparator;
private final Token gapToken = new Token(new AttributeSource());
private final Token endToken = new Token(new AttributeSource());
private final PositionIncrementAttribute incAtt = addAttribute(PositionIncrementAttribute.class);
private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class);
private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);
private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);
private Token[] currentShingleTokens;
private int currentShingleStackSize;
private boolean inputStreamExhausted = false;
/**
* Creates a FixedShingleFilter over an input token stream
*
* @param input the input stream
* @param shingleSize the shingle size
*/
@ -77,15 +81,16 @@ public final class FixedShingleFilter extends TokenFilter {
/**
* Creates a FixedShingleFilter over an input token stream
* @param input the input tokenstream
* @param shingleSize the shingle size
* @param tokenSeparator a String to use as a token separator
* @param fillerToken a String to use to represent gaps in the input stream (due to eg stopwords)
*
* @param input the input tokenstream
* @param shingleSize the shingle size
* @param tokenSeparator a String to use as a token separator
* @param fillerToken a String to use to represent gaps in the input stream (due to eg stopwords)
*/
public FixedShingleFilter(TokenStream input, int shingleSize, String tokenSeparator, String fillerToken) {
super(input);
if (shingleSize <= 1) {
throw new IllegalArgumentException("shingleSize must be two or greater");
if (shingleSize <= 1 || shingleSize > MAX_SHINGLE_SIZE) {
throw new IllegalArgumentException("Shingle size must be between 2 and " + MAX_SHINGLE_SIZE + ", got " + shingleSize);
}
this.shingleSize = shingleSize;
this.tokenSeparator = tokenSeparator;
@ -112,7 +117,6 @@ public final class FixedShingleFilter extends TokenFilter {
termAtt.setEmpty();
termAtt.append(currentShingleTokens[0].term());
typeAtt.setType("shingle");
posLenAtt.setPositionLength(shingleSize);
for (int i = 1; i < shingleSize; i++) {
termAtt.append(tokenSeparator).append(currentShingleTokens[i].term());
}
@ -125,6 +129,7 @@ public final class FixedShingleFilter extends TokenFilter {
this.tokenPool.clear();
this.currentShingleTokens[0] = null;
this.inputStreamExhausted = false;
this.currentShingleStackSize = 0;
}
@Override
@ -193,7 +198,7 @@ public final class FixedShingleFilter extends TokenFilter {
Token next = nextTokenInStream(token);
return next == endToken || next.posInc() != 0;
}
private boolean advanceStack() throws IOException {
for (int i = shingleSize - 1; i >= 1; i--) {
if (currentShingleTokens[i] != gapToken && lastInStack(currentShingleTokens[i]) == false) {
@ -201,9 +206,13 @@ public final class FixedShingleFilter extends TokenFilter {
for (int j = i + 1; j < shingleSize; j++) {
currentShingleTokens[j] = nextTokenInGraph(currentShingleTokens[j - 1]);
}
if (currentShingleStackSize++ > MAX_SHINGLE_STACK_SIZE) {
throw new IllegalStateException("Too many shingles (> " + MAX_SHINGLE_STACK_SIZE + ") at term [" + currentShingleTokens[0].term() + "]");
}
return true;
}
}
currentShingleStackSize = 0;
return false;
}
@ -249,8 +258,7 @@ public final class FixedShingleFilter extends TokenFilter {
finishInnerStream();
if (token == null) {
return endToken;
}
else {
} else {
token.nextToken = endToken;
return endToken;
}

View File

@ -43,7 +43,7 @@ public class FixedShingleFilterTest extends BaseTokenStreamTestCase {
new int[]{13, 18, 27, 32, 41,},
new String[]{"shingle", "shingle", "shingle", "shingle", "shingle",},
new int[]{1, 1, 1, 1, 1,},
new int[]{2, 2, 2, 2, 2});
new int[]{1, 1, 1, 1, 1});
}
@ -197,4 +197,32 @@ public class FixedShingleFilterTest extends BaseTokenStreamTestCase {
new int[] { 1, 0, 0, 0, 1, 0, });
}
public void testParameterLimits() {
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> {
new FixedShingleFilter(new CannedTokenStream(), 1);
});
assertEquals("Shingle size must be between 2 and 4, got 1", e.getMessage());
IllegalArgumentException e2 = expectThrows(IllegalArgumentException.class, () -> {
new FixedShingleFilter(new CannedTokenStream(), 5);
});
assertEquals("Shingle size must be between 2 and 4, got 5", e2.getMessage());
}
public void testShingleCountLimits() {
Token[] tokens = new Token[5000];
tokens[0] = new Token("term", 1, 0, 1);
tokens[1] = new Token("term1", 1, 2, 3);
for (int i = 2; i < 5000; i++) {
tokens[i] = new Token("term" + i, 0, 2, 3);
}
Exception e = expectThrows(IllegalStateException.class, () -> {
TokenStream ts = new FixedShingleFilter(new CannedTokenStream(tokens), 2);
ts.reset();
while (ts.incrementToken()) {}
});
assertEquals("Too many shingles (> 1000) at term [term]", e.getMessage());
}
}