Fix _all boosting.
_all boosting used to rely on the fact that the TokenStream doesn't eagerly consume the input java.io.Reader. This fixes the issue by using binary search in order to find the right boost given a token's start offset. Close #4315
This commit is contained in:
parent
53be1fe9d0
commit
309ee7d581
|
@ -40,14 +40,20 @@ public class AllEntries extends Reader {
|
||||||
public static class Entry {
|
public static class Entry {
|
||||||
private final String name;
|
private final String name;
|
||||||
private final FastStringReader reader;
|
private final FastStringReader reader;
|
||||||
|
private final int startOffset;
|
||||||
private final float boost;
|
private final float boost;
|
||||||
|
|
||||||
public Entry(String name, FastStringReader reader, float boost) {
|
public Entry(String name, FastStringReader reader, int startOffset, float boost) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.reader = reader;
|
this.reader = reader;
|
||||||
|
this.startOffset = startOffset;
|
||||||
this.boost = boost;
|
this.boost = boost;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int startOffset() {
|
||||||
|
return startOffset;
|
||||||
|
}
|
||||||
|
|
||||||
public String name() {
|
public String name() {
|
||||||
return this.name;
|
return this.name;
|
||||||
}
|
}
|
||||||
|
@ -75,7 +81,15 @@ public class AllEntries extends Reader {
|
||||||
if (boost != 1.0f) {
|
if (boost != 1.0f) {
|
||||||
customBoost = true;
|
customBoost = true;
|
||||||
}
|
}
|
||||||
Entry entry = new Entry(name, new FastStringReader(text), boost);
|
final int lastStartOffset;
|
||||||
|
if (entries.isEmpty()) {
|
||||||
|
lastStartOffset = -1;
|
||||||
|
} else {
|
||||||
|
final Entry last = entries.get(entries.size() - 1);
|
||||||
|
lastStartOffset = last.startOffset() + last.reader().length();
|
||||||
|
}
|
||||||
|
final int startOffset = lastStartOffset + 1; // +1 because we insert a space between tokens
|
||||||
|
Entry entry = new Entry(name, new FastStringReader(text), startOffset, boost);
|
||||||
entries.add(entry);
|
entries.add(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,8 +143,22 @@ public class AllEntries extends Reader {
|
||||||
return fields;
|
return fields;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Entry current() {
|
// compute the boost for a token with the given startOffset
|
||||||
return this.current;
|
public float boost(int startOffset) {
|
||||||
|
int lo = 0, hi = entries.size() - 1;
|
||||||
|
while (lo <= hi) {
|
||||||
|
final int mid = (lo + hi) >>> 1;
|
||||||
|
final int midOffset = entries.get(mid).startOffset();
|
||||||
|
if (startOffset < midOffset) {
|
||||||
|
hi = mid - 1;
|
||||||
|
} else {
|
||||||
|
lo = mid + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final int index = Math.max(0, hi); // protection against broken token streams
|
||||||
|
assert entries.get(index).startOffset() <= startOffset;
|
||||||
|
assert index == entries.size() - 1 || entries.get(index + 1).startOffset() > startOffset;
|
||||||
|
return entries.get(index).boost();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -186,7 +214,7 @@ public class AllEntries extends Reader {
|
||||||
@Override
|
@Override
|
||||||
public void close() {
|
public void close() {
|
||||||
if (current != null) {
|
if (current != null) {
|
||||||
current.reader().close();
|
// no need to close, these are readers on strings
|
||||||
current = null;
|
current = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.elasticsearch.common.lucene.all;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.analysis.TokenFilter;
|
import org.apache.lucene.analysis.TokenFilter;
|
||||||
import org.apache.lucene.analysis.TokenStream;
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
|
||||||
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
|
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
|
@ -42,11 +43,13 @@ public final class AllTokenStream extends TokenFilter {
|
||||||
|
|
||||||
private final AllEntries allEntries;
|
private final AllEntries allEntries;
|
||||||
|
|
||||||
|
private final OffsetAttribute offsetAttribute;
|
||||||
private final PayloadAttribute payloadAttribute;
|
private final PayloadAttribute payloadAttribute;
|
||||||
|
|
||||||
AllTokenStream(TokenStream input, AllEntries allEntries) {
|
AllTokenStream(TokenStream input, AllEntries allEntries) {
|
||||||
super(input);
|
super(input);
|
||||||
this.allEntries = allEntries;
|
this.allEntries = allEntries;
|
||||||
|
offsetAttribute = addAttribute(OffsetAttribute.class);
|
||||||
payloadAttribute = addAttribute(PayloadAttribute.class);
|
payloadAttribute = addAttribute(PayloadAttribute.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,14 +62,12 @@ public final class AllTokenStream extends TokenFilter {
|
||||||
if (!input.incrementToken()) {
|
if (!input.incrementToken()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (allEntries.current() != null) {
|
final float boost = allEntries.boost(offsetAttribute.startOffset());
|
||||||
float boost = allEntries.current().boost();
|
if (boost != 1.0f) {
|
||||||
if (boost != 1.0f) {
|
encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
|
||||||
encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
|
payloadAttribute.setPayload(payloadSpare);
|
||||||
payloadAttribute.setPayload(payloadSpare);
|
} else {
|
||||||
} else {
|
payloadAttribute.setPayload(null);
|
||||||
payloadAttribute.setPayload(null);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,11 @@
|
||||||
|
|
||||||
package org.elasticsearch.common.lucene.all;
|
package org.elasticsearch.common.lucene.all;
|
||||||
|
|
||||||
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
|
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
|
||||||
|
import org.apache.lucene.analysis.payloads.PayloadHelper;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||||
|
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.StoredField;
|
import org.apache.lucene.document.StoredField;
|
||||||
|
@ -27,6 +32,7 @@ import org.apache.lucene.index.*;
|
||||||
import org.apache.lucene.search.*;
|
import org.apache.lucene.search.*;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.store.RAMDirectory;
|
import org.apache.lucene.store.RAMDirectory;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.elasticsearch.common.lucene.Lucene;
|
import org.elasticsearch.common.lucene.Lucene;
|
||||||
import org.elasticsearch.test.ElasticsearchTestCase;
|
import org.elasticsearch.test.ElasticsearchTestCase;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -40,6 +46,51 @@ import static org.hamcrest.Matchers.equalTo;
|
||||||
*/
|
*/
|
||||||
public class SimpleAllTests extends ElasticsearchTestCase {
|
public class SimpleAllTests extends ElasticsearchTestCase {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBoostOnEagerTokenizer() throws Exception {
|
||||||
|
AllEntries allEntries = new AllEntries();
|
||||||
|
allEntries.addText("field1", "all", 2.0f);
|
||||||
|
allEntries.addText("field2", "your", 1.0f);
|
||||||
|
allEntries.addText("field1", "boosts", 0.5f);
|
||||||
|
allEntries.reset();
|
||||||
|
// whitespace analyzer's tokenizer reads characters eagerly on the contrary to the standard tokenizer
|
||||||
|
final TokenStream ts = AllTokenStream.allTokenStream("any", allEntries, new WhitespaceAnalyzer(Lucene.VERSION));
|
||||||
|
final CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
|
||||||
|
final PayloadAttribute payloadAtt = ts.addAttribute(PayloadAttribute.class);
|
||||||
|
ts.reset();
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
assertTrue(ts.incrementToken());
|
||||||
|
final String term;
|
||||||
|
final float boost;
|
||||||
|
switch (i) {
|
||||||
|
case 0:
|
||||||
|
term = "all";
|
||||||
|
boost = 2;
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
term = "your";
|
||||||
|
boost = 1;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
term = "boosts";
|
||||||
|
boost = 0.5f;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new AssertionError();
|
||||||
|
}
|
||||||
|
assertEquals(term, termAtt.toString());
|
||||||
|
final BytesRef payload = payloadAtt.getPayload();
|
||||||
|
if (payload == null || payload.length == 0) {
|
||||||
|
assertEquals(boost, 1f, 0.001f);
|
||||||
|
} else {
|
||||||
|
assertEquals(4, payload.length);
|
||||||
|
final float b = PayloadHelper.decodeFloat(payload.bytes, payload.offset);
|
||||||
|
assertEquals(boost, b, 0.001f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertFalse(ts.incrementToken());
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllEntriesRead() throws Exception {
|
public void testAllEntriesRead() throws Exception {
|
||||||
AllEntries allEntries = new AllEntries();
|
AllEntries allEntries = new AllEntries();
|
||||||
|
|
Loading…
Reference in New Issue