Fix _all boosting.

_all boosting used to rely on the TokenStream not eagerly consuming the input
java.io.Reader, which broke with eager tokenizers. This fixes the issue by using
binary search to find the right boost for a token given its start offset.

Close #4315
This commit is contained in:
Adrien Grand 2013-12-03 16:29:28 +01:00
parent 53be1fe9d0
commit 309ee7d581
3 changed files with 93 additions and 13 deletions

View File

@ -40,14 +40,20 @@ public class AllEntries extends Reader {
public static class Entry {
private final String name;
private final FastStringReader reader;
private final int startOffset;
private final float boost;
public Entry(String name, FastStringReader reader, float boost) {
/**
 * Creates a new entry for field {@code name}.
 *
 * @param name        the field name this entry belongs to
 * @param reader      reader over the field's text
 * @param startOffset offset of this entry's first character within the concatenated
 *                    _all stream; used later to look up the boost for a token
 * @param boost       index-time boost to apply to tokens from this entry
 */
public Entry(String name, FastStringReader reader, int startOffset, float boost) {
this.name = name;
this.reader = reader;
this.startOffset = startOffset;
this.boost = boost;
}
/** Returns the offset of this entry's first character within the concatenated _all stream. */
public int startOffset() {
    return this.startOffset;
}
/** Returns the name of the field this entry was built from. */
public String name() {
    return name;
}
@ -75,7 +81,15 @@ public class AllEntries extends Reader {
if (boost != 1.0f) {
customBoost = true;
}
Entry entry = new Entry(name, new FastStringReader(text), boost);
final int lastStartOffset;
if (entries.isEmpty()) {
lastStartOffset = -1;
} else {
final Entry last = entries.get(entries.size() - 1);
lastStartOffset = last.startOffset() + last.reader().length();
}
final int startOffset = lastStartOffset + 1; // +1 because we insert a space between tokens
Entry entry = new Entry(name, new FastStringReader(text), startOffset, boost);
entries.add(entry);
}
@ -129,8 +143,22 @@ public class AllEntries extends Reader {
return fields;
}
public Entry current() {
return this.current;
/**
 * Computes the boost for a token that starts at the given offset.
 *
 * Entries are stored in ascending start-offset order, so a binary search finds the
 * last entry whose start offset is less than or equal to {@code startOffset}; that
 * entry's boost applies to the token.
 */
public float boost(int startOffset) {
    int low = 0;
    int high = entries.size() - 1;
    while (low <= high) {
        final int middle = (low + high) >>> 1; // unsigned shift avoids (low+high) overflow
        if (entries.get(middle).startOffset() <= startOffset) {
            low = middle + 1;
        } else {
            high = middle - 1;
        }
    }
    // high may be -1 if a broken token stream reports an offset before the first entry
    final int index = Math.max(0, high);
    assert entries.get(index).startOffset() <= startOffset;
    assert index == entries.size() - 1 || entries.get(index + 1).startOffset() > startOffset;
    return entries.get(index).boost();
}
@Override
@ -186,7 +214,7 @@ public class AllEntries extends Reader {
@Override
// Releases the current entry; readers are backed by in-memory strings.
public void close() {
if (current != null) {
// NOTE(review): this diff shows both the removed close() call and the new comment —
// the string-backed reader needs no closing; confirm the call below was dropped upstream
current.reader().close();
// no need to close, these are readers on strings
current = null;
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.common.lucene.all;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;
@ -42,11 +43,13 @@ public final class AllTokenStream extends TokenFilter {
private final AllEntries allEntries;
private final OffsetAttribute offsetAttribute;
private final PayloadAttribute payloadAttribute;
/**
 * Wraps {@code input} and annotates each token with a payload-encoded boost looked
 * up from {@code allEntries} by the token's start offset.
 */
AllTokenStream(TokenStream input, AllEntries allEntries) {
    super(input);
    this.allEntries = allEntries;
    // register the attributes we read (offset) and write (payload) on each token
    this.offsetAttribute = addAttribute(OffsetAttribute.class);
    this.payloadAttribute = addAttribute(PayloadAttribute.class);
}
@ -59,15 +62,13 @@ public final class AllTokenStream extends TokenFilter {
if (!input.incrementToken()) {
return false;
}
if (allEntries.current() != null) {
float boost = allEntries.current().boost();
final float boost = allEntries.boost(offsetAttribute.startOffset());
if (boost != 1.0f) {
encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
payloadAttribute.setPayload(payloadSpare);
} else {
payloadAttribute.setPayload(null);
}
}
return true;
}

View File

@ -19,6 +19,11 @@
package org.elasticsearch.common.lucene.all;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StoredField;
@ -27,6 +32,7 @@ import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.Test;
@ -40,6 +46,51 @@ import static org.hamcrest.Matchers.equalTo;
*/
public class SimpleAllTests extends ElasticsearchTestCase {
@Test
public void testBoostOnEagerTokenizer() throws Exception {
    AllEntries allEntries = new AllEntries();
    allEntries.addText("field1", "all", 2.0f);
    allEntries.addText("field2", "your", 1.0f);
    allEntries.addText("field1", "boosts", 0.5f);
    allEntries.reset();
    // whitespace analyzer's tokenizer reads characters eagerly on the contrary to the standard tokenizer
    final TokenStream ts = AllTokenStream.allTokenStream("any", allEntries, new WhitespaceAnalyzer(Lucene.VERSION));
    final CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
    final PayloadAttribute payloadAtt = ts.addAttribute(PayloadAttribute.class);
    // expected terms and their boosts, in token order
    final String[] expectedTerms = new String[] { "all", "your", "boosts" };
    final float[] expectedBoosts = new float[] { 2f, 1f, 0.5f };
    ts.reset();
    for (int i = 0; i < expectedTerms.length; ++i) {
        assertTrue(ts.incrementToken());
        assertEquals(expectedTerms[i], termAtt.toString());
        final BytesRef payload = payloadAtt.getPayload();
        if (payload == null || payload.length == 0) {
            // an absent or empty payload encodes the default boost of 1
            assertEquals(expectedBoosts[i], 1f, 0.001f);
        } else {
            assertEquals(4, payload.length);
            final float decoded = PayloadHelper.decodeFloat(payload.bytes, payload.offset);
            assertEquals(expectedBoosts[i], decoded, 0.001f);
        }
    }
    assertFalse(ts.incrementToken());
}
@Test
public void testAllEntriesRead() throws Exception {
AllEntries allEntries = new AllEntries();