Merge branch 'refactor/rescore-builder-equals-hash'

This commit is contained in:
Christoph Büscher 2016-01-14 15:40:34 +01:00
commit 3d98756e64
10 changed files with 543 additions and 120 deletions

View File

@ -421,7 +421,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder<SearchRequest, Se
* @return this for chaining * @return this for chaining
*/ */
public SearchRequestBuilder addRescorer(RescoreBuilder.Rescorer rescorer) { public SearchRequestBuilder addRescorer(RescoreBuilder.Rescorer rescorer) {
sourceBuilder().addRescorer(new RescoreBuilder().rescorer(rescorer)); sourceBuilder().addRescorer(new RescoreBuilder(rescorer));
return this; return this;
} }
@ -433,7 +433,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder<SearchRequest, Se
* @return this for chaining * @return this for chaining
*/ */
public SearchRequestBuilder addRescorer(RescoreBuilder.Rescorer rescorer, int window) { public SearchRequestBuilder addRescorer(RescoreBuilder.Rescorer rescorer, int window) {
sourceBuilder().addRescorer(new RescoreBuilder().rescorer(rescorer).windowSize(window)); sourceBuilder().addRescorer(new RescoreBuilder(rescorer).windowSize(window));
return this; return this;
} }

View File

@ -37,6 +37,7 @@ import org.elasticsearch.common.geo.builders.ShapeBuilder;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
@ -676,6 +677,13 @@ public abstract class StreamInput extends InputStream {
return readNamedWriteable(ShapeBuilder.class); return readNamedWriteable(ShapeBuilder.class);
} }
/**
* Reads a {@link QueryBuilder} from the current stream
*/
public Rescorer readRescorer() throws IOException {
return readNamedWriteable(Rescorer.class);
}
/** /**
* Reads a {@link org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder} from the current stream * Reads a {@link org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder} from the current stream
*/ */

View File

@ -36,13 +36,13 @@ import org.elasticsearch.common.geo.builders.ShapeBuilder;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer;
import org.joda.time.ReadableInstant; import org.joda.time.ReadableInstant;
import java.io.EOFException; import java.io.EOFException;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.nio.channels.ClosedChannelException;
import java.nio.file.AccessDeniedException; import java.nio.file.AccessDeniedException;
import java.nio.file.AtomicMoveNotSupportedException; import java.nio.file.AtomicMoveNotSupportedException;
import java.nio.file.DirectoryNotEmptyException; import java.nio.file.DirectoryNotEmptyException;
@ -676,5 +676,12 @@ public abstract class StreamOutput extends OutputStream {
for (T obj: list) { for (T obj: list) {
obj.writeTo(this); obj.writeTo(this);
} }
}
/**
* Writes a {@link Rescorer} to the current stream
*/
public void writeRescorer(Rescorer rescorer) throws IOException {
writeNamedWriteable(rescorer);
} }
} }

View File

@ -0,0 +1,117 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.rescore;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Locale;
public enum QueryRescoreMode implements Writeable<QueryRescoreMode> {
Avg {
@Override
public float combine(float primary, float secondary) {
return (primary + secondary) / 2;
}
@Override
public String toString() {
return "avg";
}
},
Max {
@Override
public float combine(float primary, float secondary) {
return Math.max(primary, secondary);
}
@Override
public String toString() {
return "max";
}
},
Min {
@Override
public float combine(float primary, float secondary) {
return Math.min(primary, secondary);
}
@Override
public String toString() {
return "min";
}
},
Total {
@Override
public float combine(float primary, float secondary) {
return primary + secondary;
}
@Override
public String toString() {
return "sum";
}
},
Multiply {
@Override
public float combine(float primary, float secondary) {
return primary * secondary;
}
@Override
public String toString() {
return "product";
}
};
public abstract float combine(float primary, float secondary);
static QueryRescoreMode PROTOTYPE = Total;
@Override
public QueryRescoreMode readFrom(StreamInput in) throws IOException {
int ordinal = in.readVInt();
if (ordinal < 0 || ordinal >= values().length) {
throw new IOException("Unknown ScoreMode ordinal [" + ordinal + "]");
}
return values()[ordinal];
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(this.ordinal());
}
public static QueryRescoreMode fromString(String scoreMode) {
for (QueryRescoreMode mode : values()) {
if (scoreMode.toLowerCase(Locale.ROOT).equals(mode.name().toLowerCase(Locale.ROOT))) {
return mode;
}
}
throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]");
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}

View File

@ -38,66 +38,6 @@ import java.util.Set;
public final class QueryRescorer implements Rescorer { public final class QueryRescorer implements Rescorer {
private static enum ScoreMode {
Avg {
@Override
public float combine(float primary, float secondary) {
return (primary + secondary) / 2;
}
@Override
public String toString() {
return "avg";
}
},
Max {
@Override
public float combine(float primary, float secondary) {
return Math.max(primary, secondary);
}
@Override
public String toString() {
return "max";
}
},
Min {
@Override
public float combine(float primary, float secondary) {
return Math.min(primary, secondary);
}
@Override
public String toString() {
return "min";
}
},
Total {
@Override
public float combine(float primary, float secondary) {
return primary + secondary;
}
@Override
public String toString() {
return "sum";
}
},
Multiply {
@Override
public float combine(float primary, float secondary) {
return primary * secondary;
}
@Override
public String toString() {
return "product";
}
};
public abstract float combine(float primary, float secondary);
}
public static final Rescorer INSTANCE = new QueryRescorer(); public static final Rescorer INSTANCE = new QueryRescorer();
public static final String NAME = "query"; public static final String NAME = "query";
@ -170,7 +110,7 @@ public final class QueryRescorer implements Rescorer {
rescoreExplain.getValue() * secondaryWeight, rescoreExplain.getValue() * secondaryWeight,
"product of:", "product of:",
rescoreExplain, Explanation.match(secondaryWeight, "secondaryWeight")); rescoreExplain, Explanation.match(secondaryWeight, "secondaryWeight"));
ScoreMode scoreMode = rescore.scoreMode(); QueryRescoreMode scoreMode = rescore.scoreMode();
return Explanation.match( return Explanation.match(
scoreMode.combine(prim.getValue(), sec.getValue()), scoreMode.combine(prim.getValue(), sec.getValue()),
scoreMode + " of:", scoreMode + " of:",
@ -228,7 +168,7 @@ public final class QueryRescorer implements Rescorer {
// secondary score? // secondary score?
in.scoreDocs[i].score *= ctx.queryWeight(); in.scoreDocs[i].score *= ctx.queryWeight();
} }
// TODO: this is wrong, i.e. we are comparing apples and oranges at this point. It would be better if we always rescored all // TODO: this is wrong, i.e. we are comparing apples and oranges at this point. It would be better if we always rescored all
// incoming first pass hits, instead of allowing recoring of just the top subset: // incoming first pass hits, instead of allowing recoring of just the top subset:
Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR); Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR);
@ -240,13 +180,13 @@ public final class QueryRescorer implements Rescorer {
public QueryRescoreContext(QueryRescorer rescorer) { public QueryRescoreContext(QueryRescorer rescorer) {
super(NAME, 10, rescorer); super(NAME, 10, rescorer);
this.scoreMode = ScoreMode.Total; this.scoreMode = QueryRescoreMode.Total;
} }
private ParsedQuery parsedQuery; private ParsedQuery parsedQuery;
private float queryWeight = 1.0f; private float queryWeight = 1.0f;
private float rescoreQueryWeight = 1.0f; private float rescoreQueryWeight = 1.0f;
private ScoreMode scoreMode; private QueryRescoreMode scoreMode;
public void setParsedQuery(ParsedQuery parsedQuery) { public void setParsedQuery(ParsedQuery parsedQuery) {
this.parsedQuery = parsedQuery; this.parsedQuery = parsedQuery;
@ -264,7 +204,7 @@ public final class QueryRescorer implements Rescorer {
return rescoreQueryWeight; return rescoreQueryWeight;
} }
public ScoreMode scoreMode() { public QueryRescoreMode scoreMode() {
return scoreMode; return scoreMode;
} }
@ -276,26 +216,13 @@ public final class QueryRescorer implements Rescorer {
this.queryWeight = queryWeight; this.queryWeight = queryWeight;
} }
public void setScoreMode(ScoreMode scoreMode) { public void setScoreMode(QueryRescoreMode scoreMode) {
this.scoreMode = scoreMode; this.scoreMode = scoreMode;
} }
public void setScoreMode(String scoreMode) { public void setScoreMode(String scoreMode) {
if ("avg".equals(scoreMode)) { setScoreMode(QueryRescoreMode.fromString(scoreMode));
setScoreMode(ScoreMode.Avg);
} else if ("max".equals(scoreMode)) {
setScoreMode(ScoreMode.Max);
} else if ("min".equals(scoreMode)) {
setScoreMode(ScoreMode.Min);
} else if ("total".equals(scoreMode)) {
setScoreMode(ScoreMode.Total);
} else if ("multiply".equals(scoreMode)) {
setScoreMode(ScoreMode.Multiply);
} else {
throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]");
}
} }
} }
@Override @Override

View File

@ -19,24 +19,36 @@
package org.elasticsearch.search.rescore; package org.elasticsearch.search.rescore;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
public class RescoreBuilder implements ToXContent { public class RescoreBuilder implements ToXContent, Writeable<RescoreBuilder> {
private Rescorer rescorer; private Rescorer rescorer;
private Integer windowSize; private Integer windowSize;
public static final RescoreBuilder PROTOYPE = new RescoreBuilder(new QueryRescorer(new MatchAllQueryBuilder()));
public static QueryRescorer queryRescorer(QueryBuilder queryBuilder) { public RescoreBuilder(Rescorer rescorer) {
return new QueryRescorer(queryBuilder); if (rescorer == null) {
throw new IllegalArgumentException("rescorer cannot be null");
}
this.rescorer = rescorer;
} }
public RescoreBuilder rescorer(Rescorer rescorer) { public Rescorer rescorer() {
this.rescorer = rescorer; return this.rescorer;
return this;
} }
public RescoreBuilder windowSize(int windowSize) { public RescoreBuilder windowSize(int windowSize) {
@ -48,10 +60,6 @@ public class RescoreBuilder implements ToXContent {
return windowSize; return windowSize;
} }
public boolean isEmpty() {
return rescorer == null;
}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (windowSize != null) { if (windowSize != null) {
@ -61,13 +69,66 @@ public class RescoreBuilder implements ToXContent {
return builder; return builder;
} }
public static abstract class Rescorer implements ToXContent { public static QueryRescorer queryRescorer(QueryBuilder<?> queryBuilder) {
return new QueryRescorer(queryBuilder);
}
@Override
public final int hashCode() {
return Objects.hash(windowSize, rescorer);
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RescoreBuilder other = (RescoreBuilder) obj;
return Objects.equals(windowSize, other.windowSize) &&
Objects.equals(rescorer, other.rescorer);
}
@Override
public RescoreBuilder readFrom(StreamInput in) throws IOException {
RescoreBuilder builder = new RescoreBuilder(in.readRescorer());
Integer windowSize = in.readOptionalVInt();
if (windowSize != null) {
builder.windowSize(windowSize);
}
return builder;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeRescorer(rescorer);
out.writeOptionalVInt(this.windowSize);
}
@Override
public final String toString() {
try {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.prettyPrint();
builder.startObject();
toXContent(builder, EMPTY_PARAMS);
builder.endObject();
return builder.string();
} catch (Exception e) {
return "{ \"error\" : \"" + ExceptionsHelper.detailedMessage(e) + "\"}";
}
}
public static abstract class Rescorer implements ToXContent, NamedWriteable<Rescorer> {
private String name; private String name;
public Rescorer(String name) { public Rescorer(String name) {
this.name = name; this.name = name;
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(name); builder.startObject(name);
@ -78,23 +139,41 @@ public class RescoreBuilder implements ToXContent {
protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
@Override
public abstract int hashCode();
@Override
public abstract boolean equals(Object obj);
} }
public static class QueryRescorer extends Rescorer { public static class QueryRescorer extends Rescorer {
private static final String NAME = "query"; private static final String NAME = "query";
private QueryBuilder queryBuilder; public static final QueryRescorer PROTOTYPE = new QueryRescorer(new MatchAllQueryBuilder());
private Float rescoreQueryWeight; public static final float DEFAULT_RESCORE_QUERYWEIGHT = 1.0f;
private Float queryWeight; public static final float DEFAULT_QUERYWEIGHT = 1.0f;
private String scoreMode; public static final QueryRescoreMode DEFAULT_SCORE_MODE = QueryRescoreMode.Total;
private final QueryBuilder<?> queryBuilder;
private float rescoreQueryWeight = DEFAULT_RESCORE_QUERYWEIGHT;
private float queryWeight = DEFAULT_QUERYWEIGHT;
private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE;
/** /**
* Creates a new {@link QueryRescorer} instance * Creates a new {@link QueryRescorer} instance
* @param builder the query builder to build the rescore query from * @param builder the query builder to build the rescore query from
*/ */
public QueryRescorer(QueryBuilder builder) { public QueryRescorer(QueryBuilder<?> builder) {
super(NAME); super(NAME);
this.queryBuilder = builder; this.queryBuilder = builder;
} }
/**
* @return the query used for this rescore query
*/
public QueryBuilder<?> getRescoreQuery() {
return this.queryBuilder;
}
/** /**
* Sets the original query weight for rescoring. The default is <tt>1.0</tt> * Sets the original query weight for rescoring. The default is <tt>1.0</tt>
*/ */
@ -103,6 +182,14 @@ public class RescoreBuilder implements ToXContent {
return this; return this;
} }
/**
* Gets the original query weight for rescoring. The default is <tt>1.0</tt>
*/
public float getQueryWeight() {
return this.queryWeight;
}
/** /**
* Sets the original query weight for rescoring. The default is <tt>1.0</tt> * Sets the original query weight for rescoring. The default is <tt>1.0</tt>
*/ */
@ -112,27 +199,76 @@ public class RescoreBuilder implements ToXContent {
} }
/** /**
* Sets the original query score mode. The default is <tt>total</tt> * Gets the original query weight for rescoring. The default is <tt>1.0</tt>
*/ */
public QueryRescorer setScoreMode(String scoreMode) { public float getRescoreQueryWeight() {
return this.rescoreQueryWeight;
}
/**
* Sets the original query score mode. The default is {@link QueryRescoreMode#Total}.
*/
public QueryRescorer setScoreMode(QueryRescoreMode scoreMode) {
this.scoreMode = scoreMode; this.scoreMode = scoreMode;
return this; return this;
} }
/**
* Gets the original query score mode. The default is <tt>total</tt>
*/
public QueryRescoreMode getScoreMode() {
return this.scoreMode;
}
@Override @Override
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("rescore_query", queryBuilder); builder.field("rescore_query", queryBuilder);
if (queryWeight != null) { builder.field("query_weight", queryWeight);
builder.field("query_weight", queryWeight); builder.field("rescore_query_weight", rescoreQueryWeight);
} builder.field("score_mode", scoreMode.name().toLowerCase(Locale.ROOT));
if (rescoreQueryWeight != null) {
builder.field("rescore_query_weight", rescoreQueryWeight);
}
if (scoreMode != null) {
builder.field("score_mode", scoreMode);
}
return builder; return builder;
} }
}
@Override
public final int hashCode() {
return Objects.hash(getClass(), scoreMode, queryWeight, rescoreQueryWeight, queryBuilder);
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
QueryRescorer other = (QueryRescorer) obj;
return Objects.equals(scoreMode, other.scoreMode) &&
Objects.equals(queryWeight, other.queryWeight) &&
Objects.equals(rescoreQueryWeight, other.rescoreQueryWeight) &&
Objects.equals(queryBuilder, other.queryBuilder);
}
@Override
public QueryRescorer readFrom(StreamInput in) throws IOException {
QueryRescorer rescorer = new QueryRescorer(in.readQuery());
rescorer.setScoreMode(QueryRescoreMode.PROTOTYPE.readFrom(in));
rescorer.setRescoreQueryWeight(in.readFloat());
rescorer.setQueryWeight(in.readFloat());
return rescorer;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeQuery(queryBuilder);
scoreMode.writeTo(out);
out.writeFloat(rescoreQueryWeight);
out.writeFloat(queryWeight);
}
@Override
public String getWriteableName() {
return NAME;
}
}
} }

View File

@ -274,8 +274,7 @@ public class SearchSourceBuilderTests extends ESTestCase {
int numRescores = randomIntBetween(1, 5); int numRescores = randomIntBetween(1, 5);
for (int i = 0; i < numRescores; i++) { for (int i = 0; i < numRescores; i++) {
// NORELEASE need a random rescore builder method // NORELEASE need a random rescore builder method
RescoreBuilder rescoreBuilder = new RescoreBuilder(); RescoreBuilder rescoreBuilder = new RescoreBuilder(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20),
rescoreBuilder.rescorer(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20),
randomAsciiOfLengthBetween(5, 20)))); randomAsciiOfLengthBetween(5, 20))));
builder.addRescorer(rescoreBuilder); builder.addRescorer(rescoreBuilder);
} }

View File

@ -37,6 +37,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.rescore.QueryRescoreMode;
import org.elasticsearch.search.rescore.RescoreBuilder; import org.elasticsearch.search.rescore.RescoreBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer; import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
@ -541,7 +542,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); .setQueryWeight(0.5f).setRescoreQueryWeight(0.4f);
if (!"".equals(scoreModes[innerMode])) { if (!"".equals(scoreModes[innerMode])) {
innerRescoreQuery.setScoreMode(scoreModes[innerMode]); innerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[innerMode]));
} }
SearchResponse searchResponse = client() SearchResponse searchResponse = client()
@ -564,7 +565,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.boost(4.0f)).setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); .boost(4.0f)).setQueryWeight(0.5f).setRescoreQueryWeight(0.4f);
if (!"".equals(scoreModes[outerMode])) { if (!"".equals(scoreModes[outerMode])) {
outerRescoreQuery.setScoreMode(scoreModes[outerMode]); outerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[outerMode]));
} }
searchResponse = client() searchResponse = client()
@ -612,7 +613,7 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setRescoreQueryWeight(secondaryWeight); .setRescoreQueryWeight(secondaryWeight);
if (!"".equals(scoreMode)) { if (!"".equals(scoreMode)) {
rescoreQuery.setScoreMode(scoreMode); rescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreMode));
} }
SearchResponse rescored = client() SearchResponse rescored = client()
@ -683,11 +684,11 @@ public class QueryRescorerIT extends ESIntegTestCase {
int numDocs = indexRandomNumbers("keyword", 1, true); int numDocs = indexRandomNumbers("keyword", 1, true);
QueryRescorer eightIsGreat = RescoreBuilder.queryRescorer( QueryRescorer eightIsGreat = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(8)), QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(8)),
ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode("total"); ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
QueryRescorer sevenIsBetter = RescoreBuilder.queryRescorer( QueryRescorer sevenIsBetter = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(7)), QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(7)),
ScoreFunctionBuilders.weightFactorFunction(10000.0f)).boostMode(CombineFunction.REPLACE)) ScoreFunctionBuilders.weightFactorFunction(10000.0f)).boostMode(CombineFunction.REPLACE))
.setScoreMode("total"); .setScoreMode(QueryRescoreMode.Total);
// First set the rescore window large enough that both rescores take effect // First set the rescore window large enough that both rescores take effect
SearchRequestBuilder request = client().prepareSearch(); SearchRequestBuilder request = client().prepareSearch();
@ -704,10 +705,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
// Now use one rescore to drag the number we're looking for into the window of another // Now use one rescore to drag the number we're looking for into the window of another
QueryRescorer ninetyIsGood = RescoreBuilder.queryRescorer( QueryRescorer ninetyIsGood = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*ninety*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f)) QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*ninety*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f))
.boostMode(CombineFunction.REPLACE)).setScoreMode("total"); .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
QueryRescorer oneToo = RescoreBuilder.queryRescorer( QueryRescorer oneToo = RescoreBuilder.queryRescorer(
QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*one*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f)) QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*one*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f))
.boostMode(CombineFunction.REPLACE)).setScoreMode("total"); .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total);
request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10); request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10);
response = request.setSize(2).get(); response = request.setSize(2).get();
assertFirstHit(response, hasId("91")); assertFirstHit(response, hasId("91"));

View File

@ -0,0 +1,170 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.rescore;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer;
import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer;
import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class QueryRescoreBuilderTests extends ESTestCase {
private static final int NUMBER_OF_TESTBUILDERS = 20;
private static NamedWriteableRegistry namedWriteableRegistry;
/**
* setup for the whole base test class
*/
@BeforeClass
public static void init() {
namedWriteableRegistry = new NamedWriteableRegistry();
namedWriteableRegistry.registerPrototype(Rescorer.class, org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer.PROTOTYPE);
namedWriteableRegistry.registerPrototype(QueryBuilder.class, new MatchAllQueryBuilder());
}
@AfterClass
public static void afterClass() throws Exception {
namedWriteableRegistry = null;
}
/**
* Test serialization and deserialization of the rescore builder
*/
public void testSerialization() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) {
RescoreBuilder original = randomRescoreBuilder();
RescoreBuilder deserialized = serializedCopy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
}
/**
* Test equality and hashCode properties
*/
public void testEqualsAndHashcode() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) {
RescoreBuilder firstBuilder = randomRescoreBuilder();
assertFalse("rescore builder is equal to null", firstBuilder.equals(null));
assertFalse("rescore builder is equal to incompatible type", firstBuilder.equals(""));
assertTrue("rescore builder is not equal to self", firstBuilder.equals(firstBuilder));
assertThat("same rescore builder's hashcode returns different values if called multiple times", firstBuilder.hashCode(),
equalTo(firstBuilder.hashCode()));
assertThat("different rescore builder should not be equal", mutate(firstBuilder), not(equalTo(firstBuilder)));
RescoreBuilder secondBuilder = serializedCopy(firstBuilder);
assertTrue("rescore builder is not equal to self", secondBuilder.equals(secondBuilder));
assertTrue("rescore builder is not equal to its copy", firstBuilder.equals(secondBuilder));
assertTrue("equals is not symmetric", secondBuilder.equals(firstBuilder));
assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(firstBuilder.hashCode()));
RescoreBuilder thirdBuilder = serializedCopy(secondBuilder);
assertTrue("rescore builder is not equal to self", thirdBuilder.equals(thirdBuilder));
assertTrue("rescore builder is not equal to its copy", secondBuilder.equals(thirdBuilder));
assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(thirdBuilder.hashCode()));
assertTrue("equals is not transitive", firstBuilder.equals(thirdBuilder));
assertThat("rescore builder copy's hashcode is different from original hashcode", firstBuilder.hashCode(), equalTo(thirdBuilder.hashCode()));
assertTrue("equals is not symmetric", thirdBuilder.equals(secondBuilder));
assertTrue("equals is not symmetric", thirdBuilder.equals(firstBuilder));
}
}
private RescoreBuilder mutate(RescoreBuilder original) throws IOException {
RescoreBuilder mutation = serializedCopy(original);
if (randomBoolean()) {
Integer windowSize = original.windowSize();
if (windowSize != null) {
mutation.windowSize(windowSize + 1);
} else {
mutation.windowSize(randomIntBetween(0, 100));
}
} else {
QueryRescorer queryRescorer = (QueryRescorer) mutation.rescorer();
switch (randomIntBetween(0, 3)) {
case 0:
queryRescorer.setQueryWeight(queryRescorer.getQueryWeight() + 0.1f);
break;
case 1:
queryRescorer.setRescoreQueryWeight(queryRescorer.getRescoreQueryWeight() + 0.1f);
break;
case 2:
QueryRescoreMode other;
do {
other = randomFrom(QueryRescoreMode.values());
} while (other == queryRescorer.getScoreMode());
queryRescorer.setScoreMode(other);
break;
case 3:
// only increase the boost to make it a slightly different query
queryRescorer.getRescoreQuery().boost(queryRescorer.getRescoreQuery().boost() + 0.1f);
break;
default:
throw new IllegalStateException("unexpected random mutation in test");
}
}
return mutation;
}
/**
* create random shape that is put under test
*/
private static RescoreBuilder randomRescoreBuilder() {
QueryBuilder<MatchAllQueryBuilder> queryBuilder = new MatchAllQueryBuilder().boost(randomFloat()).queryName(randomAsciiOfLength(20));
org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer rescorer = new
org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer(queryBuilder);
if (randomBoolean()) {
rescorer.setQueryWeight(randomFloat());
}
if (randomBoolean()) {
rescorer.setRescoreQueryWeight(randomFloat());
}
if (randomBoolean()) {
rescorer.setScoreMode(randomFrom(QueryRescoreMode.values()));
}
RescoreBuilder builder = new RescoreBuilder(rescorer);
if (randomBoolean()) {
builder.windowSize(randomIntBetween(0, 100));
}
return builder;
}
private static RescoreBuilder serializedCopy(RescoreBuilder original) throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
original.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) {
return RescoreBuilder.PROTOYPE.readFrom(in);
}
}
}
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.rescore;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
/**
* Test fixing the ordinals and names in {@link QueryRescoreMode}. These should not be changed since we
* use the names in the parser and the ordinals in serialization.
*/
public class QueryRescoreModeTests extends ESTestCase {
/**
* Test @link {@link QueryRescoreMode} enum ordinals and names, since serilaization relies on it
*/
public void testQueryRescoreMode() throws IOException {
float primary = randomFloat();
float secondary = randomFloat();
assertEquals(0, QueryRescoreMode.Avg.ordinal());
assertEquals("avg", QueryRescoreMode.Avg.toString());
assertEquals((primary + secondary)/2.0f, QueryRescoreMode.Avg.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(1, QueryRescoreMode.Max.ordinal());
assertEquals("max", QueryRescoreMode.Max.toString());
assertEquals(Math.max(primary, secondary), QueryRescoreMode.Max.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(2, QueryRescoreMode.Min.ordinal());
assertEquals("min", QueryRescoreMode.Min.toString());
assertEquals(Math.min(primary, secondary), QueryRescoreMode.Min.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(3, QueryRescoreMode.Total.ordinal());
assertEquals("sum", QueryRescoreMode.Total.toString());
assertEquals(primary + secondary, QueryRescoreMode.Total.combine(primary, secondary), Float.MIN_VALUE);
assertEquals(4, QueryRescoreMode.Multiply.ordinal());
assertEquals("product", QueryRescoreMode.Multiply.toString());
assertEquals(primary * secondary, QueryRescoreMode.Multiply.combine(primary, secondary), Float.MIN_VALUE);
}
}