FunctionScore: Refactor RandomScoreFunction to be consistent, and return values in range [0.0, 1.0]

RandomScoreFunction previously relied on the order the documents were
iterated in from Lucene. This caused changes in ordering, with the same
seed, if documents moved to different segments. With this change, a
murmur32 hash of the _uid for each document is used as the "random"
value. Also, the hash is adjusted so as to only return values between
0.0 and 1.0 to enable easier manipulation to fit into users' scoring
models.

closes #6907, #7446
This commit is contained in:
Ryan Ernst 2014-08-27 08:37:25 -07:00
parent 3aa72f2738
commit 65afa1d93b
7 changed files with 113 additions and 253 deletions

View File

@ -140,8 +140,12 @@ not.
===== Random ===== Random
The `random_score` generates scores via a pseudo random number algorithm The `random_score` generates scores using a hash of the `_uid` field,
that is initialized with a `seed`. with a `seed` for variation. If `seed` is not specified, the current
time is used.
NOTE: Using this feature will load field data for `_uid`, which can
be a memory intensive operation since the values are unique.
[source,js] [source,js]
-------------------------------------------------- --------------------------------------------------

View File

@ -20,57 +20,54 @@ package org.elasticsearch.common.lucene.search.function;
import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.apache.lucene.util.StringHelper;
import org.elasticsearch.index.fielddata.AtomicFieldData;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
/** /**
* Pseudo randomly generate a score for each {@link #score}. * Pseudo randomly generate a score for each {@link #score}.
*/ */
public class RandomScoreFunction extends ScoreFunction { public class RandomScoreFunction extends ScoreFunction {
private final PRNG prng; private int originalSeed;
private int saltedSeed;
private final IndexFieldData<?> uidFieldData;
private SortedBinaryDocValues uidByteData;
public RandomScoreFunction(long seed) { /**
* Creates a RandomScoreFunction.
*
* @param seed A seed for randomness
* @param salt A value to salt the seed with, ideally unique to the running node/index
* @param uidFieldData The field data for _uid to use for generating consistent random values for the same id
*/
public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData) {
super(CombineFunction.MULT); super(CombineFunction.MULT);
this.prng = new PRNG(seed); this.originalSeed = seed;
this.saltedSeed = seed ^ salt;
this.uidFieldData = uidFieldData;
if (uidFieldData == null) throw new NullPointerException("uid missing");
} }
@Override @Override
public void setNextReader(AtomicReaderContext context) { public void setNextReader(AtomicReaderContext context) {
// intentionally does nothing AtomicFieldData leafData = uidFieldData.load(context);
uidByteData = leafData.getBytesValues();
if (uidByteData == null) throw new NullPointerException("failed to get uid byte data");
} }
@Override @Override
public double score(int docId, float subQueryScore) { public double score(int docId, float subQueryScore) {
return prng.nextFloat(); uidByteData.setDocument(docId);
int hash = StringHelper.murmurhash3_x86_32(uidByteData.valueAt(0), saltedSeed);
return (hash & 0x00FFFFFF) / (float)(1 << 24); // only use the lower 24 bits to construct a float from 0.0-1.0
} }
@Override @Override
public Explanation explainScore(int docId, float subQueryScore) { public Explanation explainScore(int docId, float subQueryScore) {
Explanation exp = new Explanation(); Explanation exp = new Explanation();
exp.setDescription("random score function (seed: " + prng.originalSeed + ")"); exp.setDescription("random score function (seed: " + originalSeed + ")");
return exp; return exp;
} }
/**
* A non thread-safe PRNG
*/
static class PRNG {
private static final long multiplier = 0x5DEECE66DL;
private static final long addend = 0xBL;
private static final long mask = (1L << 48) - 1;
final long originalSeed;
long seed;
PRNG(long seed) {
this.originalSeed = seed;
this.seed = (seed ^ multiplier) & mask;
}
public float nextFloat() {
seed = (seed * multiplier + addend) & mask;
return seed / (float)(1 << 24);
}
}
} }

View File

@ -75,7 +75,7 @@ public class ScoreFunctionBuilders {
return (new FactorBuilder()).boostFactor(boost); return (new FactorBuilder()).boostFactor(boost);
} }
public static RandomScoreFunctionBuilder randomFunction(long seed) { public static RandomScoreFunctionBuilder randomFunction(int seed) {
return (new RandomScoreFunctionBuilder()).seed(seed); return (new RandomScoreFunctionBuilder()).seed(seed);
} }

View File

@ -28,7 +28,7 @@ import java.io.IOException;
*/ */
public class RandomScoreFunctionBuilder implements ScoreFunctionBuilder { public class RandomScoreFunctionBuilder implements ScoreFunctionBuilder {
private Long seed = null; private Integer seed = null;
public RandomScoreFunctionBuilder() { public RandomScoreFunctionBuilder() {
} }
@ -44,7 +44,7 @@ public class RandomScoreFunctionBuilder implements ScoreFunctionBuilder {
* *
* @param seed The seed. * @param seed The seed.
*/ */
public RandomScoreFunctionBuilder seed(long seed) { public RandomScoreFunctionBuilder seed(int seed) {
this.seed = seed; this.seed = seed;
return this; return this;
} }
@ -53,7 +53,7 @@ public class RandomScoreFunctionBuilder implements ScoreFunctionBuilder {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(getName()); builder.startObject(getName());
if (seed != null) { if (seed != null) {
builder.field("seed", seed.longValue()); builder.field("seed", seed.intValue());
} }
return builder.endObject(); return builder.endObject();
} }

View File

@ -24,6 +24,8 @@ import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.lucene.search.function.RandomScoreFunction; import org.elasticsearch.common.lucene.search.function.RandomScoreFunction;
import org.elasticsearch.common.lucene.search.function.ScoreFunction; import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.index.query.QueryParsingException; import org.elasticsearch.index.query.QueryParsingException;
import org.elasticsearch.index.query.functionscore.ScoreFunctionParser; import org.elasticsearch.index.query.functionscore.ScoreFunctionParser;
@ -32,9 +34,6 @@ import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
/**
*
*/
public class RandomScoreFunctionParser implements ScoreFunctionParser { public class RandomScoreFunctionParser implements ScoreFunctionParser {
public static String[] NAMES = { "random_score", "randomScore" }; public static String[] NAMES = { "random_score", "randomScore" };
@ -51,7 +50,7 @@ public class RandomScoreFunctionParser implements ScoreFunctionParser {
@Override @Override
public ScoreFunction parse(QueryParseContext parseContext, XContentParser parser) throws IOException, QueryParsingException { public ScoreFunction parse(QueryParseContext parseContext, XContentParser parser) throws IOException, QueryParsingException {
long seed = -1; int seed = -1;
String currentFieldName = null; String currentFieldName = null;
XContentParser.Token token; XContentParser.Token token;
@ -60,7 +59,7 @@ public class RandomScoreFunctionParser implements ScoreFunctionParser {
currentFieldName = parser.currentName(); currentFieldName = parser.currentName();
} else if (token.isValue()) { } else if (token.isValue()) {
if ("seed".equals(currentFieldName)) { if ("seed".equals(currentFieldName)) {
seed = parser.longValue(); seed = parser.intValue();
} else { } else {
throw new QueryParsingException(parseContext.index(), NAMES[0] + " query does not support [" + currentFieldName + "]"); throw new QueryParsingException(parseContext.index(), NAMES[0] + " query does not support [" + currentFieldName + "]");
} }
@ -68,20 +67,15 @@ public class RandomScoreFunctionParser implements ScoreFunctionParser {
} }
if (seed == -1) { if (seed == -1) {
seed = parseContext.nowInMillis(); seed = (int)parseContext.nowInMillis();
} }
ShardId shardId = SearchContext.current().indexShard().shardId(); ShardId shardId = SearchContext.current().indexShard().shardId();
seed = salt(seed, shardId.index().name(), shardId.id()); int salt = (shardId.index().name().hashCode() << 10) | shardId.id();
return new RandomScoreFunction(seed); final FieldMapper<?> mapper = SearchContext.current().mapperService().smartNameFieldMapper("_uid");
IndexFieldData<?> uidFieldData = SearchContext.current().fieldData().getForField(mapper);
return new RandomScoreFunction(seed, salt, uidFieldData);
} }
public static long salt(long seed, String index, int shardId) {
long salt = index.hashCode();
salt = salt << 32;
salt |= shardId;
return salt^seed;
}
} }

View File

@ -1,193 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.lucene.search.function;
import com.google.common.collect.Lists;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.RAMDirectory;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.After;
import org.junit.Test;
import java.io.IOException;
import java.util.List;
import static org.hamcrest.Matchers.*;
/**
 * Unit tests for {@link RandomScoreFunction}: consistency of the internal
 * PRNG for a fixed seed, basic variation of its output, randomized ordering
 * of search hits, and the seed reported by the score explanation.
 */
public class RandomScoreFunctionTests extends ElasticsearchTestCase {
// IDs of the three documents indexed by mockSearcher()
private final String[] ids = { "1", "2", "3" };
// Writer/reader over an in-memory index; created lazily by mockSearcher()
private IndexWriter writer;
private AtomicReader reader;
// Release Lucene resources if a test created them via mockSearcher()
@After
public void closeReaderAndWriterIfUsed() throws IOException {
if (reader != null) {
reader.close();
}
if (writer != null) {
writer.close();
}
}
/**
 * Create a "mock" {@link IndexSearcher} that uses an in-memory directory
 * containing three documents whose IDs are "1", "2", and "3" respectively.
 * @return Never {@code null}
 * @throws IOException if an unexpected error occurs while mocking
 */
private IndexSearcher mockSearcher() throws IOException {
writer = new IndexWriter(new RAMDirectory(), new IndexWriterConfig(Lucene.VERSION, Lucene.STANDARD_ANALYZER));
for (String id : ids) {
Document document = new Document();
document.add(new TextField("_id", id, Field.Store.YES));
writer.addDocument(document);
}
// SlowCompositeReaderWrapper exposes the (possibly multi-segment) index as one AtomicReader
reader = SlowCompositeReaderWrapper.wrap(DirectoryReader.open(writer, true));
return new IndexSearcher(reader);
}
/**
 * Given the same seed, the pseudo random number generator should match on
 * each use given the same number of invocations.
 */
@Test
public void testPrngNextFloatIsConsistent() {
long seed = randomLong();
RandomScoreFunction.PRNG prng = new RandomScoreFunction.PRNG(seed);
RandomScoreFunction.PRNG prng2 = new RandomScoreFunction.PRNG(seed);
// The seed will be changing the entire time, so each value should be
// different
assertThat(prng.nextFloat(), equalTo(prng2.nextFloat()));
assertThat(prng.nextFloat(), equalTo(prng2.nextFloat()));
assertThat(prng.nextFloat(), equalTo(prng2.nextFloat()));
assertThat(prng.nextFloat(), equalTo(prng2.nextFloat()));
}
// Sanity check: across many random seeds, the first draw should sometimes
// exceed the second (i.e. the generator is not monotonically increasing)
@Test
public void testPrngNextFloatSometimesFirstIsGreaterThanSecond() {
boolean firstWasGreater = false;
// Since the results themselves are intended to be random, we cannot
// just do @Repeat(iterations = 100) because some iterations are
// expected to fail
for (int i = 0; i < 100; ++i) {
long seed = randomLong();
RandomScoreFunction.PRNG prng = new RandomScoreFunction.PRNG(seed);
float firstRandom = prng.nextFloat();
float secondRandom = prng.nextFloat();
if (firstRandom > secondRandom) {
firstWasGreater = true;
}
}
assertTrue("First value was never greater than the second value", firstWasGreater);
}
// Counterpart of the test above: the first draw should sometimes be the
// smaller one (note this variant uses 1000 iterations, not 100)
@Test
public void testPrngNextFloatSometimesFirstIsLessThanSecond() {
boolean firstWasLess = false;
// Since the results themselves are intended to be random, we cannot
// just do @Repeat(iterations = 100) because some iterations are
// expected to fail
for (int i = 0; i < 1000; ++i) {
long seed = randomLong();
RandomScoreFunction.PRNG prng = new RandomScoreFunction.PRNG(seed);
float firstRandom = prng.nextFloat();
float secondRandom = prng.nextFloat();
if (firstRandom < secondRandom) {
firstWasLess = true;
}
}
assertTrue("First value was never less than the second value", firstWasLess);
}
// Each of the three documents should eventually rank first under some seed;
// the loop short-circuits once every id has been spotted in the top slot
@Test
public void testScorerResultsInRandomOrder() throws IOException {
List<String> idsNotSpotted = Lists.newArrayList(ids);
IndexSearcher searcher = mockSearcher();
// Since the results themselves are intended to be random, we cannot
// just do @Repeat(iterations = 100) because some iterations are
// expected to fail
for (int i = 0; i < 100; ++i) {
// Randomly seeded to keep trying to shuffle without walking through
// values
RandomScoreFunction function = new RandomScoreFunction(randomLong());
// fulfilling contract
function.setNextReader(reader.getContext());
FunctionScoreQuery query = new FunctionScoreQuery(Queries.newMatchAllQuery(), function);
// Testing that we get a random result
TopDocs docs = searcher.search(query, 1);
String id = reader.document(docs.scoreDocs[0].doc).getField("_id").stringValue();
if (idsNotSpotted.remove(id) && idsNotSpotted.isEmpty()) {
// short circuit test because we succeeded
break;
}
}
assertThat(idsNotSpotted, empty());
}
// The explanation must echo the seed the caller passed in, not the PRNG's
// internal state after it has been advanced by a score() call
@Test
public void testExplainScoreReportsOriginalSeed() {
long seed = randomLong();
Explanation subExplanation = new Explanation();
RandomScoreFunction function = new RandomScoreFunction(seed);
// Trigger a random call to change the seed to ensure that we are
// reporting the _original_ seed
function.score(0, 1.0f);
// Generate the randomScore explanation
Explanation randomExplanation = function.explainScore(0, subExplanation.getValue());
// Original seed should be there
assertThat(randomExplanation.getDescription(), containsString("" + seed));
}
}

View File

@ -18,7 +18,9 @@
*/ */
package org.elasticsearch.search.functionscore; package org.elasticsearch.search.functionscore;
import org.apache.lucene.search.Explanation;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.query.functionscore.random.RandomScoreFunctionBuilder;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ElasticsearchIntegrationTest; import org.elasticsearch.test.ElasticsearchIntegrationTest;
@ -34,14 +36,13 @@ import static org.elasticsearch.index.query.QueryBuilders.*;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*; import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.*;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.nullValue;
public class RandomScoreFunctionTests extends ElasticsearchIntegrationTest { public class RandomScoreFunctionTests extends ElasticsearchIntegrationTest {
@Test @Test
public void consistentHitsWithSameSeed() throws Exception { public void testConsistentHitsWithSameSeed() throws Exception {
createIndex("test"); createIndex("test");
ensureGreen(); // make sure we are done otherwise preference could change? ensureGreen(); // make sure we are done otherwise preference could change?
int docCount = randomIntBetween(100, 200); int docCount = randomIntBetween(100, 200);
@ -52,7 +53,7 @@ public class RandomScoreFunctionTests extends ElasticsearchIntegrationTest {
refresh(); refresh();
int outerIters = scaledRandomIntBetween(10, 20); int outerIters = scaledRandomIntBetween(10, 20);
for (int o = 0; o < outerIters; o++) { for (int o = 0; o < outerIters; o++) {
final long seed = randomLong(); final int seed = randomInt();
String preference = randomRealisticUnicodeOfLengthBetween(1, 10); // at least one char!! String preference = randomRealisticUnicodeOfLengthBetween(1, 10); // at least one char!!
// randomPreference should not start with '_' (reserved for known preference types (e.g. _shards, _primary) // randomPreference should not start with '_' (reserved for known preference types (e.g. _shards, _primary)
while (preference.startsWith("_")) { while (preference.startsWith("_")) {
@ -73,10 +74,20 @@ public class RandomScoreFunctionTests extends ElasticsearchIntegrationTest {
} else { } else {
assertThat(hits.getHits().length, equalTo(searchResponse.getHits().getHits().length)); assertThat(hits.getHits().length, equalTo(searchResponse.getHits().getHits().length));
for (int j = 0; j < hitCount; j++) { for (int j = 0; j < hitCount; j++) {
assertThat(searchResponse.getHits().getAt(j).score(), equalTo(hits.getAt(j).score()));
assertThat(searchResponse.getHits().getAt(j).id(), equalTo(hits.getAt(j).id())); assertThat(searchResponse.getHits().getAt(j).id(), equalTo(hits.getAt(j).id()));
assertThat(searchResponse.getHits().getAt(j).score(), equalTo(hits.getAt(j).score()));
} }
} }
// randomly change some docs to get them in different segments
int numDocsToChange = randomIntBetween(20, 50);
while (numDocsToChange > 0) {
int doc = randomInt(docCount);
index("test", "type", "" + doc, jsonBuilder().startObject().endObject());
--numDocsToChange;
}
flush();
refresh();
} }
} }
} }
@ -148,9 +159,56 @@ public class RandomScoreFunctionTests extends ElasticsearchIntegrationTest {
assertThat(firstHit.getScore(), greaterThan(1f)); assertThat(firstHit.getScore(), greaterThan(1f));
} }
@Test
public void testSeedReportedInExplain() throws Exception {
createIndex("test");
ensureGreen();
index("test", "type", "1", jsonBuilder().startObject().endObject());
flush();
refresh();
int seed = 12345678;
SearchResponse resp = client().prepareSearch("test")
.setQuery(functionScoreQuery(matchAllQuery(), randomFunction(seed)))
.setExplain(true)
.get();
assertNoFailures(resp);
assertEquals(1, resp.getHits().totalHits());
SearchHit firstHit = resp.getHits().getAt(0);
assertThat(firstHit.explanation().toString(), containsString("" + seed));
}
@Test
public void testScoreRange() throws Exception {
// all random scores should be in range [0.0, 1.0]
createIndex("test");
ensureGreen();
int docCount = randomIntBetween(100, 200);
for (int i = 0; i < docCount; i++) {
String id = randomRealisticUnicodeOfCodepointLengthBetween(1, 50);
index("test", "type", id, jsonBuilder().startObject().endObject());
}
flush();
refresh();
int iters = scaledRandomIntBetween(10, 20);
for (int i = 0; i < iters; ++i) {
int seed = randomInt();
SearchResponse searchResponse = client().prepareSearch()
.setQuery(functionScoreQuery(matchAllQuery(), randomFunction(seed)))
.setSize(docCount)
.execute().actionGet();
assertNoFailures(searchResponse);
for (SearchHit hit : searchResponse.getHits().getHits()) {
assertThat(hit.score(), allOf(greaterThanOrEqualTo(0.0f), lessThanOrEqualTo(1.0f)));
}
}
}
@Test @Test
@Ignore @Ignore
public void distribution() throws Exception { public void checkDistribution() throws Exception {
int count = 10000; int count = 10000;
assertAcked(prepareCreate("test")); assertAcked(prepareCreate("test"));
@ -168,7 +226,7 @@ public class RandomScoreFunctionTests extends ElasticsearchIntegrationTest {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
SearchResponse searchResponse = client().prepareSearch() SearchResponse searchResponse = client().prepareSearch()
.setQuery(functionScoreQuery(matchAllQuery(), randomFunction(System.nanoTime()))) .setQuery(functionScoreQuery(matchAllQuery(), new RandomScoreFunctionBuilder()))
.execute().actionGet(); .execute().actionGet();
matrix[Integer.valueOf(searchResponse.getHits().getAt(0).id())]++; matrix[Integer.valueOf(searchResponse.getHits().getAt(0).id())]++;