Make RescoreBuilder and nested QueryRescorer Writable

Adding serialization capabilities to RescoreBuilder and make
all QueryRescorer implement NamedWritable, also requiring
all implementations of RescoreBuilder.Rescorer to implement
equals() and hashCode.

In addition, the current rescore mode enumeration is pulled out to a
separate class to make sharing of constants easier between
the query builders XContent rendering coder and the parser.
This commit is contained in:
Christoph Büscher 2015-12-14 13:04:10 +01:00
parent a2796b555f
commit 76192024a8
10 changed files with 543 additions and 120 deletions

View File

@ -421,7 +421,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder<SearchRequest, Se
* @return this for chaining
*/
public SearchRequestBuilder addRescorer(RescoreBuilder.Rescorer rescorer) {
sourceBuilder().addRescorer(new RescoreBuilder().rescorer(rescorer));
sourceBuilder().addRescorer(new RescoreBuilder(rescorer));
return this;
}
@ -433,7 +433,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder<SearchRequest, Se
* @return this for chaining
*/
public SearchRequestBuilder addRescorer(RescoreBuilder.Rescorer rescorer, int window) {
sourceBuilder().addRescorer(new RescoreBuilder().rescorer(rescorer).windowSize(window));
sourceBuilder().addRescorer(new RescoreBuilder(rescorer).windowSize(window));
return this;
}

View File

@ -37,6 +37,7 @@ import org.elasticsearch.common.geo.builders.ShapeBuilder;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
@ -676,6 +677,13 @@ public abstract class StreamInput extends InputStream {
return readNamedWriteable(ShapeBuilder.class);
}
/**
* Reads a {@link QueryBuilder} from the current stream
*/
public Rescorer readRescorer() throws IOException {
return readNamedWriteable(Rescorer.class);
}
/**
* Reads a {@link org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder} from the current stream
*/

View File

@ -36,13 +36,13 @@ import org.elasticsearch.common.geo.builders.ShapeBuilder;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer;
import org.joda.time.ReadableInstant;
import java.io.EOFException;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.ClosedChannelException;
import java.nio.file.AccessDeniedException;
import java.nio.file.AtomicMoveNotSupportedException;
import java.nio.file.DirectoryNotEmptyException;
@ -676,5 +676,12 @@ public abstract class StreamOutput extends OutputStream {
for (T obj: list) {
obj.writeTo(this);
}
}
/**
* Writes a {@link Rescorer} to the current stream
*/
public void writeRescorer(Rescorer rescorer) throws IOException {
writeNamedWriteable(rescorer);
}
}

View File

@ -0,0 +1,117 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.rescore;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Locale;
public enum QueryRescoreMode implements Writeable<QueryRescoreMode> {
Avg {
@Override
public float combine(float primary, float secondary) {
return (primary + secondary) / 2;
}
@Override
public String toString() {
return "avg";
}
},
Max {
@Override
public float combine(float primary, float secondary) {
return Math.max(primary, secondary);
}
@Override
public String toString() {
return "max";
}
},
Min {
@Override
public float combine(float primary, float secondary) {
return Math.min(primary, secondary);
}
@Override
public String toString() {
return "min";
}
},
Total {
@Override
public float combine(float primary, float secondary) {
return primary + secondary;
}
@Override
public String toString() {
return "sum";
}
},
Multiply {
@Override
public float combine(float primary, float secondary) {
return primary * secondary;
}
@Override
public String toString() {
return "product";
}
};
public abstract float combine(float primary, float secondary);
static QueryRescoreMode PROTOTYPE = Total;
@Override
public QueryRescoreMode readFrom(StreamInput in) throws IOException {
int ordinal = in.readVInt();
if (ordinal < 0 || ordinal >= values().length) {
throw new IOException("Unknown ScoreMode ordinal [" + ordinal + "]");
}
return values()[ordinal];
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(this.ordinal());
}
public static QueryRescoreMode fromString(String scoreMode) {
for (QueryRescoreMode mode : values()) {
if (scoreMode.toLowerCase(Locale.ROOT).equals(mode.name().toLowerCase(Locale.ROOT))) {
return mode;
}
}
throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]");
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}

View File

@ -38,66 +38,6 @@ import java.util.Set;
public final class QueryRescorer implements Rescorer {
private static enum ScoreMode {
Avg {
@Override
public float combine(float primary, float secondary) {
return (primary + secondary) / 2;
}
@Override
public String toString() {
return "avg";
}
},
Max {
@Override
public float combine(float primary, float secondary) {
return Math.max(primary, secondary);
}
@Override
public String toString() {
return "max";
}
},
Min {
@Override
public float combine(float primary, float secondary) {
return Math.min(primary, secondary);
}
@Override
public String toString() {
return "min";
}
},
Total {
@Override
public float combine(float primary, float secondary) {
return primary + secondary;
}
@Override
public String toString() {
return "sum";
}
},
Multiply {
@Override
public float combine(float primary, float secondary) {
return primary * secondary;
}
@Override
public String toString() {
return "product";
}
};
public abstract float combine(float primary, float secondary);
}
public static final Rescorer INSTANCE = new QueryRescorer();
public static final String NAME = "query";
@ -170,7 +110,7 @@ public final class QueryRescorer implements Rescorer {
rescoreExplain.getValue() * secondaryWeight,
"product of:",
rescoreExplain, Explanation.match(secondaryWeight, "secondaryWeight"));
ScoreMode scoreMode = rescore.scoreMode();
QueryRescoreMode scoreMode = rescore.scoreMode();
return Explanation.match(
scoreMode.combine(prim.getValue(), sec.getValue()),
scoreMode + " of:",
@ -228,7 +168,7 @@ public final class QueryRescorer implements Rescorer {
// secondary score?
in.scoreDocs[i].score *= ctx.queryWeight();
}
// TODO: this is wrong, i.e. we are comparing apples and oranges at this point. It would be better if we always rescored all
// incoming first pass hits, instead of allowing recoring of just the top subset:
Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR);
@ -240,13 +180,13 @@ public final class QueryRescorer implements Rescorer {
public QueryRescoreContext(QueryRescorer rescorer) {
super(NAME, 10, rescorer);
this.scoreMode = ScoreMode.Total;
this.scoreMode = QueryRescoreMode.Total;
}
private ParsedQuery parsedQuery;
private float queryWeight = 1.0f;
private float rescoreQueryWeight = 1.0f;
private ScoreMode scoreMode;
private QueryRescoreMode scoreMode;
public void setParsedQuery(ParsedQuery parsedQuery) {
this.parsedQuery = parsedQuery;
@ -264,7 +204,7 @@ public final class QueryRescorer implements Rescorer {
return rescoreQueryWeight;
}
public ScoreMode scoreMode() {
public QueryRescoreMode scoreMode() {
return scoreMode;
}
@ -276,26 +216,13 @@ public final class QueryRescorer implements Rescorer {
this.queryWeight = queryWeight;
}
public void setScoreMode(ScoreMode scoreMode) {
public void setScoreMode(QueryRescoreMode scoreMode) {
this.scoreMode = scoreMode;
}
public void setScoreMode(String scoreMode) {
if ("avg".equals(scoreMode)) {
setScoreMode(ScoreMode.Avg);
} else if ("max".equals(scoreMode)) {
setScoreMode(ScoreMode.Max);
} else if ("min".equals(scoreMode)) {
setScoreMode(ScoreMode.Min);
} else if ("total".equals(scoreMode)) {
setScoreMode(ScoreMode.Total);
} else if ("multiply".equals(scoreMode)) {
setScoreMode(ScoreMode.Multiply);
} else {
throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]");
}
setScoreMode(QueryRescoreMode.fromString(scoreMode));
}
}
@Override

View File

@ -19,24 +19,36 @@
package org.elasticsearch.search.rescore;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
public class RescoreBuilder implements ToXContent {
public class RescoreBuilder implements ToXContent, Writeable<RescoreBuilder> {
private Rescorer rescorer;
private Integer windowSize;
public static final RescoreBuilder PROTOYPE = new RescoreBuilder(new QueryRescorer(new MatchAllQueryBuilder()));
public static QueryRescorer queryRescorer(QueryBuilder queryBuilder) {
return new QueryRescorer(queryBuilder);
public RescoreBuilder(Rescorer rescorer) {
if (rescorer == null) {
throw new IllegalArgumentException("rescorer cannot be null");
}
this.rescorer = rescorer;
}
public RescoreBuilder rescorer(Rescorer rescorer) {
this.rescorer = rescorer;
return this;
public Rescorer rescorer() {
return this.rescorer;
}
public RescoreBuilder windowSize(int windowSize) {
@ -48,10 +60,6 @@ public class RescoreBuilder implements ToXContent {
return windowSize;
}
public boolean isEmpty() {
return rescorer == null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (windowSize != null) {
@ -61,13 +69,66 @@ public class RescoreBuilder implements ToXContent {
return builder;
}
public static abstract class Rescorer implements ToXContent {
public static QueryRescorer queryRescorer(QueryBuilder<?> queryBuilder) {
return new QueryRescorer(queryBuilder);
}
@Override
public final int hashCode() {
return Objects.hash(windowSize, rescorer);
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RescoreBuilder other = (RescoreBuilder) obj;
return Objects.equals(windowSize, other.windowSize) &&
Objects.equals(rescorer, other.rescorer);
}
@Override
public RescoreBuilder readFrom(StreamInput in) throws IOException {
RescoreBuilder builder = new RescoreBuilder(in.readRescorer());
Integer windowSize = in.readOptionalVInt();
if (windowSize != null) {
builder.windowSize(windowSize);
}
return builder;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeRescorer(rescorer);
out.writeOptionalVInt(this.windowSize);
}
@Override
public final String toString() {
try {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.prettyPrint();
builder.startObject();
toXContent(builder, EMPTY_PARAMS);
builder.endObject();
return builder.string();
} catch (Exception e) {
return "{ \"error\" : \"" + ExceptionsHelper.detailedMessage(e) + "\"}";
}
}
public static abstract class Rescorer implements ToXContent, NamedWriteable<Rescorer> {
private String name;
public Rescorer(String name) {
this.name = name;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(name);
@ -78,23 +139,41 @@ public class RescoreBuilder implements ToXContent {
protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
@Override
public abstract int hashCode();
@Override
public abstract boolean equals(Object obj);
}
public static class QueryRescorer extends Rescorer {
private static final String NAME = "query";
private QueryBuilder queryBuilder;
private Float rescoreQueryWeight;
private Float queryWeight;
private String scoreMode;
public static final QueryRescorer PROTOTYPE = new QueryRescorer(new MatchAllQueryBuilder());
public static final float DEFAULT_RESCORE_QUERYWEIGHT = 1.0f;
public static final float DEFAULT_QUERYWEIGHT = 1.0f;
public static final QueryRescoreMode DEFAULT_SCORE_MODE = QueryRescoreMode.Total;
private final QueryBuilder<?> queryBuilder;
private float rescoreQueryWeight = DEFAULT_RESCORE_QUERYWEIGHT;
private float queryWeight = DEFAULT_QUERYWEIGHT;
private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE;
/**
* Creates a new {@link QueryRescorer} instance
* @param builder the query builder to build the rescore query from
*/
public QueryRescorer(QueryBuilder builder) {
public QueryRescorer(QueryBuilder<?> builder) {
super(NAME);
this.queryBuilder = builder;
}
/**
* @return the query used for this rescore query
*/
public QueryBuilder<?> getRescoreQuery() {
return this.queryBuilder;
}
/**
* Sets the original query weight for rescoring. The default is <tt>1.0</tt>
*/
@ -103,6 +182,14 @@ public class RescoreBuilder implements ToXContent {
return this;
}
/**
* Gets the original query weight for rescoring. The default is <tt>1.0</tt>
*/
public float getQueryWeight() {
return this.queryWeight;
}
/**
* Sets the original query weight for rescoring. The default is <tt>1.0</tt>
*/
@ -112,27 +199,76 @@ public class RescoreBuilder implements ToXContent {
}
/**
* Sets the original query score mode. The default is <tt>total</tt>
* Gets the original query weight for rescoring. The default is <tt>1.0</tt>
*/
public QueryRescorer setScoreMode(String scoreMode) {
public float getRescoreQueryWeight() {
return this.rescoreQueryWeight;
}
/**
* Sets the original query score mode. The default is {@link QueryRescoreMode#Total}.
*/
public QueryRescorer setScoreMode(QueryRescoreMode scoreMode) {
this.scoreMode = scoreMode;
return this;
}
/**
* Gets the original query score mode. The default is <tt>total</tt>
*/
public QueryRescoreMode getScoreMode() {
return this.scoreMode;
}
@Override
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("rescore_query", queryBuilder);
if (queryWeight != null) {
builder.field("query_weight", queryWeight);
}
if (rescoreQueryWeight != null) {
builder.field("rescore_query_weight", rescoreQueryWeight);
}
if (scoreMode != null) {
builder.field("score_mode", scoreMode);
}
builder.field("query_weight", queryWeight);
builder.field("rescore_query_weight", rescoreQueryWeight);
builder.field("score_mode", scoreMode.name().toLowerCase(Locale.ROOT));
return builder;
}
}
@Override
public final int hashCode() {
return Objects.hash(getClass(), scoreMode, queryWeight, rescoreQueryWeight, queryBuilder);
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
QueryRescorer other = (QueryRescorer) obj;
return Objects.equals(scoreMode, other.scoreMode) &&
Objects.equals(queryWeight, other.queryWeight) &&
Objects.equals(rescoreQueryWeight, other.rescoreQueryWeight) &&
Objects.equals(queryBuilder, other.queryBuilder);
}
@Override
public QueryRescorer readFrom(StreamInput in) throws IOException {
QueryRescorer rescorer = new QueryRescorer(in.readQuery());
rescorer.setScoreMode(QueryRescoreMode.PROTOTYPE.readFrom(in));
rescorer.setRescoreQueryWeight(in.readFloat());
rescorer.setQueryWeight(in.readFloat());
return rescorer;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeQuery(queryBuilder);
scoreMode.writeTo(out);
out.writeFloat(rescoreQueryWeight);
out.writeFloat(queryWeight);
}
@Override
public String getWriteableName() {
return NAME;
}
}
}

View File

@ -274,8 +274,7 @@ public class SearchSourceBuilderTests extends ESTestCase {
int numRescores = randomIntBetween(1, 5);
for (int i = 0; i < numRescores; i++) {
// NORELEASE need a random rescore builder method
RescoreBuilder rescoreBuilder = new RescoreBuilder();
rescoreBuilder.rescorer(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20),
RescoreBuilder rescoreBuilder = new RescoreBuilder(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20),
randomAsciiOfLengthBetween(5, 20))));
builder.addRescorer(rescoreBuilder);
}

View File

@ -37,6 +37,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.rescore.QueryRescoreMode;
import org.elasticsearch.search.rescore.RescoreBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer;
import org.elasticsearch.test.ESIntegTestCase;
@ -541,7 +542,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f);
if (!"".equals(scoreModes[innerMode])) {
innerRescoreQuery.setScoreMode(scoreModes[innerMode]);
innerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[innerMode]));
}
SearchResponse searchResponse = client()
@ -564,7 +565,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.boost(4.0f)).setQueryWeight(0.5f).setRescoreQueryWeight(0.4f);
if (!"".equals(scoreModes[outerMode])) {
outerRescoreQuery.setScoreMode(scoreModes[outerMode]);
outerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[outerMode]));
}
searchResponse = client()
@ -612,7 +613,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setRescoreQueryWeight(secondaryWeight);
if (!"".equals(scoreMode)) {
rescoreQuery.setScoreMode(scoreMode);
rescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreMode));
}
SearchResponse rescored = client()
@ -683,11 +684,11 @@ public class QueryRescorerIT extends ESIntegTestCase {
int numDocs = indexRandomNumbers("keyword", 1, true);
QueryRescorer eightIsGreat = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(8)),
ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode("total");
ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
QueryRescorer sevenIsBetter = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(7)),
ScoreFunctionBuilders.weightFactorFunction(10000.0f)).boostMode(CombineFunction.REPLACE))
.setScoreMode("total");
.setScoreMode(QueryRescoreMode.Total);
// First set the rescore window large enough that both rescores take effect
SearchRequestBuilder request = client().prepareSearch();
@ -704,10 +705,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
// Now use one rescore to drag the number we're looking for into the window of another
QueryRescorer ninetyIsGood = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*ninety*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f))
.boostMode(CombineFunction.REPLACE)).setScoreMode("total");
.boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
QueryRescorer oneToo = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*one*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f))
.boostMode(CombineFunction.REPLACE)).setScoreMode("total");
.boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10);
response = request.setSize(2).get();
assertFirstHit(response, hasId("91"));

View File

@ -0,0 +1,170 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.rescore;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer;
import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer;
import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class QueryRescoreBuilderTests extends ESTestCase {
private static final int NUMBER_OF_TESTBUILDERS = 20;
private static NamedWriteableRegistry namedWriteableRegistry;
/**
* setup for the whole base test class
*/
@BeforeClass
public static void init() {
namedWriteableRegistry = new NamedWriteableRegistry();
namedWriteableRegistry.registerPrototype(Rescorer.class, org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer.PROTOTYPE);
namedWriteableRegistry.registerPrototype(QueryBuilder.class, new MatchAllQueryBuilder());
}
@AfterClass
public static void afterClass() throws Exception {
namedWriteableRegistry = null;
}
/**
* Test serialization and deserialization of the rescore builder
*/
public void testSerialization() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) {
RescoreBuilder original = randomRescoreBuilder();
RescoreBuilder deserialized = serializedCopy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
}
/**
* Test equality and hashCode properties
*/
public void testEqualsAndHashcode() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) {
RescoreBuilder firstBuilder = randomRescoreBuilder();
assertFalse("rescore builder is equal to null", firstBuilder.equals(null));
assertFalse("rescore builder is equal to incompatible type", firstBuilder.equals(""));
assertTrue("rescore builder is not equal to self", firstBuilder.equals(firstBuilder));
assertThat("same rescore builder's hashcode returns different values if called multiple times", firstBuilder.hashCode(),
equalTo(firstBuilder.hashCode()));
assertThat("different rescore builder should not be equal", mutate(firstBuilder), not(equalTo(firstBuilder)));
RescoreBuilder secondBuilder = serializedCopy(firstBuilder);
assertTrue("rescore builder is not equal to self", secondBuilder.equals(secondBuilder));
assertTrue("rescore builder is not equal to its copy", firstBuilder.equals(secondBuilder));
assertTrue("equals is not symmetric", secondBuilder.equals(firstBuilder));
assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(firstBuilder.hashCode()));
RescoreBuilder thirdBuilder = serializedCopy(secondBuilder);
assertTrue("rescore builder is not equal to self", thirdBuilder.equals(thirdBuilder));
assertTrue("rescore builder is not equal to its copy", secondBuilder.equals(thirdBuilder));
assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(thirdBuilder.hashCode()));
assertTrue("equals is not transitive", firstBuilder.equals(thirdBuilder));
assertThat("rescore builder copy's hashcode is different from original hashcode", firstBuilder.hashCode(), equalTo(thirdBuilder.hashCode()));
assertTrue("equals is not symmetric", thirdBuilder.equals(secondBuilder));
assertTrue("equals is not symmetric", thirdBuilder.equals(firstBuilder));
}
}
private RescoreBuilder mutate(RescoreBuilder original) throws IOException {
RescoreBuilder mutation = serializedCopy(original);
if (randomBoolean()) {
Integer windowSize = original.windowSize();
if (windowSize != null) {
mutation.windowSize(windowSize + 1);
} else {
mutation.windowSize(randomIntBetween(0, 100));
}
} else {
QueryRescorer queryRescorer = (QueryRescorer) mutation.rescorer();
switch (randomIntBetween(0, 3)) {
case 0:
queryRescorer.setQueryWeight(queryRescorer.getQueryWeight() + 0.1f);
break;
case 1:
queryRescorer.setRescoreQueryWeight(queryRescorer.getRescoreQueryWeight() + 0.1f);
break;
case 2:
QueryRescoreMode other;
do {
other = randomFrom(QueryRescoreMode.values());
} while (other == queryRescorer.getScoreMode());
queryRescorer.setScoreMode(other);
break;
case 3:
// only increase the boost to make it a slightly different query
queryRescorer.getRescoreQuery().boost(queryRescorer.getRescoreQuery().boost() + 0.1f);
break;
default:
throw new IllegalStateException("unexpected random mutation in test");
}
}
return mutation;
}
/**
* create random shape that is put under test
*/
private static RescoreBuilder randomRescoreBuilder() {
QueryBuilder<MatchAllQueryBuilder> queryBuilder = new MatchAllQueryBuilder().boost(randomFloat()).queryName(randomAsciiOfLength(20));
org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer rescorer = new
org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer(queryBuilder);
if (randomBoolean()) {
rescorer.setQueryWeight(randomFloat());
}
if (randomBoolean()) {
rescorer.setRescoreQueryWeight(randomFloat());
}
if (randomBoolean()) {
rescorer.setScoreMode(randomFrom(QueryRescoreMode.values()));
}
RescoreBuilder builder = new RescoreBuilder(rescorer);
if (randomBoolean()) {
builder.windowSize(randomIntBetween(0, 100));
}
return builder;
}
private static RescoreBuilder serializedCopy(RescoreBuilder original) throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
original.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) {
return RescoreBuilder.PROTOYPE.readFrom(in);
}
}
}
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.rescore;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
/**
* Test fixing the ordinals and names in {@link QueryRescoreMode}. These should not be changed since we
* use the names in the parser and the ordinals in serialization.
*/
public class QueryRescoreModeTests extends ESTestCase {
/**
* Test @link {@link QueryRescoreMode} enum ordinals and names, since serilaization relies on it
*/
public void testQueryRescoreMode() throws IOException {
float primary = randomFloat();
float secondary = randomFloat();
assertEquals(0, QueryRescoreMode.Avg.ordinal());
assertEquals("avg", QueryRescoreMode.Avg.toString());
assertEquals((primary + secondary)/2.0f, QueryRescoreMode.Avg.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(1, QueryRescoreMode.Max.ordinal());
assertEquals("max", QueryRescoreMode.Max.toString());
assertEquals(Math.max(primary, secondary), QueryRescoreMode.Max.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(2, QueryRescoreMode.Min.ordinal());
assertEquals("min", QueryRescoreMode.Min.toString());
assertEquals(Math.min(primary, secondary), QueryRescoreMode.Min.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(3, QueryRescoreMode.Total.ordinal());
assertEquals("sum", QueryRescoreMode.Total.toString());
assertEquals(primary + secondary, QueryRescoreMode.Total.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(4, QueryRescoreMode.Multiply.ordinal());
assertEquals("product", QueryRescoreMode.Multiply.toString());
assertEquals(primary * secondary, QueryRescoreMode.Multiply.combine(primary, secondary), Float.MIN_VALUE);
}
}