From 76192024a8ecc69ddb8a13b96d38aec1d4308243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Mon, 14 Dec 2015 13:04:10 +0100 Subject: [PATCH] Make RescoreBuilder and nested QueryRescorer Writable Adding serialization capabilities to RescoreBuilder and make all QueryRescorer implement NamedWritable, also requiring all implementations of RescoreBuilder.Rescorer to implement equals() and hashCode. In addition, the current rescore mode enumeration is pulled out to a separate class to make sharing of constants easier between the query builders XContent rendering coder and the parser. --- .../action/search/SearchRequestBuilder.java | 4 +- .../common/io/stream/StreamInput.java | 8 + .../common/io/stream/StreamOutput.java | 9 +- .../search/rescore/QueryRescoreMode.java | 117 +++++++++++ .../search/rescore/QueryRescorer.java | 87 +------- .../search/rescore/RescoreBuilder.java | 192 +++++++++++++++--- .../builder/SearchSourceBuilderTests.java | 3 +- .../search/functionscore/QueryRescorerIT.java | 15 +- .../rescore/QueryRescoreBuilderTests.java | 170 ++++++++++++++++ .../search/rescore/QueryRescoreModeTests.java | 58 ++++++ 10 files changed, 543 insertions(+), 120 deletions(-) create mode 100644 core/src/main/java/org/elasticsearch/search/rescore/QueryRescoreMode.java create mode 100644 core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java create mode 100644 core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java diff --git a/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java b/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java index 52d45ec9407..1557c266bd4 100644 --- a/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java +++ b/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java @@ -421,7 +421,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder { + Avg { + @Override + public float combine(float primary, float secondary) { + return (primary + secondary) / 2; + } + + @Override + public String toString() { + return "avg"; + } + }, + Max { + @Override + public float combine(float primary, float secondary) { + return Math.max(primary, secondary); + } + + @Override + public String toString() { + return "max"; + } + }, + Min { + @Override + public float combine(float primary, float secondary) { + return Math.min(primary, secondary); + } + + @Override + public String toString() { + return "min"; + } + }, + Total { + @Override + public float combine(float primary, float secondary) { + return primary + secondary; + } + + @Override + public String toString() { + return "sum"; + } + }, + Multiply { + @Override + public float combine(float primary, float secondary) { + return primary * secondary; + } + + @Override + public String toString() { + return "product"; + } + }; + + public abstract float combine(float primary, float secondary); + + static QueryRescoreMode PROTOTYPE = Total; + + @Override + public QueryRescoreMode readFrom(StreamInput in) throws IOException { + int ordinal = in.readVInt(); + if (ordinal < 0 || ordinal >= values().length) { + throw new IOException("Unknown ScoreMode ordinal [" + ordinal + "]"); + } + return values()[ordinal]; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(this.ordinal()); + } + + public static QueryRescoreMode fromString(String scoreMode) { + for (QueryRescoreMode mode : values()) { + if (scoreMode.toLowerCase(Locale.ROOT).equals(mode.name().toLowerCase(Locale.ROOT))) { + return mode; + } + } + throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]"); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} \ No newline at end of file diff --git a/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java b/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java index afbd034840e..7f95ff10824 100644 --- a/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java +++ b/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java @@ -38,66 +38,6 @@ import java.util.Set; public final class QueryRescorer implements Rescorer { - private static enum ScoreMode { - Avg { - @Override - public float combine(float primary, float secondary) { - return (primary + secondary) / 2; - } - - @Override - public String toString() { - return "avg"; - } - }, - Max { - @Override - public float combine(float primary, float secondary) { - return Math.max(primary, secondary); - } - - @Override - public String toString() { - return "max"; - } - }, - Min { - @Override - public float combine(float primary, float secondary) { - return Math.min(primary, secondary); - } - - @Override - public String toString() { - return "min"; - } - }, - Total { - @Override - public float combine(float primary, float secondary) { - return primary + secondary; - } - - @Override - public String toString() { - return "sum"; - } - }, - Multiply { - @Override - public float combine(float primary, float secondary) { - return primary * secondary; - } - - @Override - public String toString() { - return "product"; - } - }; - - public abstract float combine(float primary, float secondary); - } - public static final Rescorer INSTANCE = new QueryRescorer(); public static final String NAME = "query"; @@ -170,7 +110,7 @@ public final class QueryRescorer implements Rescorer { rescoreExplain.getValue() * secondaryWeight, "product of:", rescoreExplain, Explanation.match(secondaryWeight, "secondaryWeight")); - ScoreMode scoreMode = rescore.scoreMode(); + QueryRescoreMode scoreMode = rescore.scoreMode(); return Explanation.match( scoreMode.combine(prim.getValue(), sec.getValue()), scoreMode + " of:", @@ -228,7 +168,7 @@ public final class QueryRescorer implements Rescorer { // secondary score? in.scoreDocs[i].score *= ctx.queryWeight(); } - + // TODO: this is wrong, i.e. we are comparing apples and oranges at this point. It would be better if we always rescored all // incoming first pass hits, instead of allowing recoring of just the top subset: Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR); @@ -240,13 +180,13 @@ public final class QueryRescorer implements Rescorer { public QueryRescoreContext(QueryRescorer rescorer) { super(NAME, 10, rescorer); - this.scoreMode = ScoreMode.Total; + this.scoreMode = QueryRescoreMode.Total; } private ParsedQuery parsedQuery; private float queryWeight = 1.0f; private float rescoreQueryWeight = 1.0f; - private ScoreMode scoreMode; + private QueryRescoreMode scoreMode; public void setParsedQuery(ParsedQuery parsedQuery) { this.parsedQuery = parsedQuery; @@ -264,7 +204,7 @@ public final class QueryRescorer implements Rescorer { return rescoreQueryWeight; } - public ScoreMode scoreMode() { + public QueryRescoreMode scoreMode() { return scoreMode; } @@ -276,26 +216,13 @@ public final class QueryRescorer implements Rescorer { this.queryWeight = queryWeight; } - public void setScoreMode(ScoreMode scoreMode) { + public void setScoreMode(QueryRescoreMode scoreMode) { this.scoreMode = scoreMode; } public void setScoreMode(String scoreMode) { - if ("avg".equals(scoreMode)) { - setScoreMode(ScoreMode.Avg); - } else if ("max".equals(scoreMode)) { - setScoreMode(ScoreMode.Max); - } else if ("min".equals(scoreMode)) { - setScoreMode(ScoreMode.Min); - } else if ("total".equals(scoreMode)) { - setScoreMode(ScoreMode.Total); - } else if ("multiply".equals(scoreMode)) { - setScoreMode(ScoreMode.Multiply); - } else { - throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]"); - } + setScoreMode(QueryRescoreMode.fromString(scoreMode)); } - } @Override diff --git a/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java b/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java index fde282427d7..7510d24f82d 100644 --- a/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java @@ -19,24 +19,36 @@ package org.elasticsearch.search.rescore; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; +import java.util.Locale; +import java.util.Objects; -public class RescoreBuilder implements ToXContent { +public class RescoreBuilder implements ToXContent, Writeable { private Rescorer rescorer; private Integer windowSize; + public static final RescoreBuilder PROTOYPE = new RescoreBuilder(new QueryRescorer(new MatchAllQueryBuilder())); - public static QueryRescorer queryRescorer(QueryBuilder queryBuilder) { - return new QueryRescorer(queryBuilder); + public RescoreBuilder(Rescorer rescorer) { + if (rescorer == null) { + throw new IllegalArgumentException("rescorer cannot be null"); + } + this.rescorer = rescorer; } - public RescoreBuilder rescorer(Rescorer rescorer) { - this.rescorer = rescorer; - return this; + public Rescorer rescorer() { + return this.rescorer; } public RescoreBuilder windowSize(int windowSize) { @@ -48,10 +60,6 @@ public class RescoreBuilder implements ToXContent { return windowSize; } - public boolean isEmpty() { - return rescorer == null; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { if (windowSize != null) { @@ -61,13 +69,66 @@ public class RescoreBuilder implements ToXContent { return builder; } - public static abstract class Rescorer implements ToXContent { + public static QueryRescorer queryRescorer(QueryBuilder queryBuilder) { + return new QueryRescorer(queryBuilder); + } + + @Override + public final int hashCode() { + return Objects.hash(windowSize, rescorer); + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + RescoreBuilder other = (RescoreBuilder) obj; + return Objects.equals(windowSize, other.windowSize) && + Objects.equals(rescorer, other.rescorer); + } + + @Override + public RescoreBuilder readFrom(StreamInput in) throws IOException { + RescoreBuilder builder = new RescoreBuilder(in.readRescorer()); + Integer windowSize = in.readOptionalVInt(); + if (windowSize != null) { + builder.windowSize(windowSize); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeRescorer(rescorer); + out.writeOptionalVInt(this.windowSize); + } + + @Override + public final String toString() { + try { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.prettyPrint(); + builder.startObject(); + toXContent(builder, EMPTY_PARAMS); + builder.endObject(); + return builder.string(); + } catch (Exception e) { + return "{ \"error\" : \"" + ExceptionsHelper.detailedMessage(e) + "\"}"; + } + } + + public static abstract class Rescorer implements ToXContent, NamedWriteable { private String name; public Rescorer(String name) { this.name = name; } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(name); @@ -78,23 +139,41 @@ public class RescoreBuilder implements ToXContent { protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; + @Override + public abstract int hashCode(); + + @Override + public abstract boolean equals(Object obj); } public static class QueryRescorer extends Rescorer { + private static final String NAME = "query"; - private QueryBuilder queryBuilder; - private Float rescoreQueryWeight; - private Float queryWeight; - private String scoreMode; + public static final QueryRescorer PROTOTYPE = new QueryRescorer(new MatchAllQueryBuilder()); + public static final float DEFAULT_RESCORE_QUERYWEIGHT = 1.0f; + public static final float DEFAULT_QUERYWEIGHT = 1.0f; + public static final QueryRescoreMode DEFAULT_SCORE_MODE = QueryRescoreMode.Total; + private final QueryBuilder queryBuilder; + private float rescoreQueryWeight = DEFAULT_RESCORE_QUERYWEIGHT; + private float queryWeight = DEFAULT_QUERYWEIGHT; + private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE; /** * Creates a new {@link QueryRescorer} instance * @param builder the query builder to build the rescore query from */ - public QueryRescorer(QueryBuilder builder) { + public QueryRescorer(QueryBuilder builder) { super(NAME); this.queryBuilder = builder; } + + /** + * @return the query used for this rescore query + */ + public QueryBuilder getRescoreQuery() { + return this.queryBuilder; + } + /** * Sets the original query weight for rescoring. The default is 1.0 */ @@ -103,6 +182,14 @@ public class RescoreBuilder implements ToXContent { return this; } + + /** + * Gets the original query weight for rescoring. The default is 1.0 + */ + public float getQueryWeight() { + return this.queryWeight; + } + /** * Sets the original query weight for rescoring. The default is 1.0 */ @@ -112,27 +199,76 @@ public class RescoreBuilder implements ToXContent { } /** - * Sets the original query score mode. The default is total + * Gets the original query weight for rescoring. The default is 1.0 */ - public QueryRescorer setScoreMode(String scoreMode) { + public float getRescoreQueryWeight() { + return this.rescoreQueryWeight; + } + + /** + * Sets the original query score mode. The default is {@link QueryRescoreMode#Total}. + */ + public QueryRescorer setScoreMode(QueryRescoreMode scoreMode) { this.scoreMode = scoreMode; return this; } + /** + * Gets the original query score mode. The default is total + */ + public QueryRescoreMode getScoreMode() { + return this.scoreMode; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("rescore_query", queryBuilder); - if (queryWeight != null) { - builder.field("query_weight", queryWeight); - } - if (rescoreQueryWeight != null) { - builder.field("rescore_query_weight", rescoreQueryWeight); - } - if (scoreMode != null) { - builder.field("score_mode", scoreMode); - } + builder.field("query_weight", queryWeight); + builder.field("rescore_query_weight", rescoreQueryWeight); + builder.field("score_mode", scoreMode.name().toLowerCase(Locale.ROOT)); return builder; } - } + @Override + public final int hashCode() { + return Objects.hash(getClass(), scoreMode, queryWeight, rescoreQueryWeight, queryBuilder); + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + QueryRescorer other = (QueryRescorer) obj; + return Objects.equals(scoreMode, other.scoreMode) && + Objects.equals(queryWeight, other.queryWeight) && + Objects.equals(rescoreQueryWeight, other.rescoreQueryWeight) && + Objects.equals(queryBuilder, other.queryBuilder); + } + + @Override + public QueryRescorer readFrom(StreamInput in) throws IOException { + QueryRescorer rescorer = new QueryRescorer(in.readQuery()); + rescorer.setScoreMode(QueryRescoreMode.PROTOTYPE.readFrom(in)); + rescorer.setRescoreQueryWeight(in.readFloat()); + rescorer.setQueryWeight(in.readFloat()); + return rescorer; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeQuery(queryBuilder); + scoreMode.writeTo(out); + out.writeFloat(rescoreQueryWeight); + out.writeFloat(queryWeight); + } + + @Override + public String getWriteableName() { + return NAME; + } + } } diff --git a/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index b80810fc6d5..d7ede712447 100644 --- a/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -274,8 +274,7 @@ public class SearchSourceBuilderTests extends ESTestCase { int numRescores = randomIntBetween(1, 5); for (int i = 0; i < numRescores; i++) { // NORELEASE need a random rescore builder method - RescoreBuilder rescoreBuilder = new RescoreBuilder(); - rescoreBuilder.rescorer(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20), + RescoreBuilder rescoreBuilder = new RescoreBuilder(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20), randomAsciiOfLengthBetween(5, 20)))); builder.addRescorer(rescoreBuilder); } diff --git a/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java b/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java index 861701a9f6a..5644f893603 100644 --- a/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java +++ b/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java @@ -37,6 +37,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.RescoreBuilder; import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer; import org.elasticsearch.test.ESIntegTestCase; @@ -541,7 +542,7 @@ public class QueryRescorerIT extends ESIntegTestCase { .setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); if (!"".equals(scoreModes[innerMode])) { - innerRescoreQuery.setScoreMode(scoreModes[innerMode]); + innerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[innerMode])); } SearchResponse searchResponse = client() @@ -564,7 +565,7 @@ public class QueryRescorerIT extends ESIntegTestCase { .boost(4.0f)).setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); if (!"".equals(scoreModes[outerMode])) { - outerRescoreQuery.setScoreMode(scoreModes[outerMode]); + outerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[outerMode])); } searchResponse = client() @@ -612,7 +613,7 @@ public class QueryRescorerIT extends ESIntegTestCase { .setRescoreQueryWeight(secondaryWeight); if (!"".equals(scoreMode)) { - rescoreQuery.setScoreMode(scoreMode); + rescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreMode)); } SearchResponse rescored = client() @@ -683,11 +684,11 @@ public class QueryRescorerIT extends ESIntegTestCase { int numDocs = indexRandomNumbers("keyword", 1, true); QueryRescorer eightIsGreat = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(8)), - ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode("total"); + ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); QueryRescorer sevenIsBetter = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(7)), ScoreFunctionBuilders.weightFactorFunction(10000.0f)).boostMode(CombineFunction.REPLACE)) - .setScoreMode("total"); + .setScoreMode(QueryRescoreMode.Total); // First set the rescore window large enough that both rescores take effect SearchRequestBuilder request = client().prepareSearch(); @@ -704,10 +705,10 @@ public class QueryRescorerIT extends ESIntegTestCase { // Now use one rescore to drag the number we're looking for into the window of another QueryRescorer ninetyIsGood = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*ninety*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f)) - .boostMode(CombineFunction.REPLACE)).setScoreMode("total"); + .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); QueryRescorer oneToo = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*one*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f)) - .boostMode(CombineFunction.REPLACE)).setScoreMode("total"); + .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10); response = request.setSize(2).get(); assertFirstHit(response, hasId("91")); diff --git a/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java new file mode 100644 index 00000000000..2aa55f8b626 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java @@ -0,0 +1,170 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.rescore; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer; +import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class QueryRescoreBuilderTests extends ESTestCase { + + private static final int NUMBER_OF_TESTBUILDERS = 20; + private static NamedWriteableRegistry namedWriteableRegistry; + + /** + * setup for the whole base test class + */ + @BeforeClass + public static void init() { + namedWriteableRegistry = new NamedWriteableRegistry(); + namedWriteableRegistry.registerPrototype(Rescorer.class, org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer.PROTOTYPE); + namedWriteableRegistry.registerPrototype(QueryBuilder.class, new MatchAllQueryBuilder()); + } + + @AfterClass + public static void afterClass() throws Exception { + namedWriteableRegistry = null; + } + + /** + * Test serialization and deserialization of the rescore builder + */ + public void testSerialization() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) { + RescoreBuilder original = randomRescoreBuilder(); + RescoreBuilder deserialized = serializedCopy(original); + assertEquals(deserialized, original); + assertEquals(deserialized.hashCode(), original.hashCode()); + assertNotSame(deserialized, original); + } + } + + /** + * Test equality and hashCode properties + */ + public void testEqualsAndHashcode() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) { + RescoreBuilder firstBuilder = randomRescoreBuilder(); + assertFalse("rescore builder is equal to null", firstBuilder.equals(null)); + assertFalse("rescore builder is equal to incompatible type", firstBuilder.equals("")); + assertTrue("rescore builder is not equal to self", firstBuilder.equals(firstBuilder)); + assertThat("same rescore builder's hashcode returns different values if called multiple times", firstBuilder.hashCode(), + equalTo(firstBuilder.hashCode())); + assertThat("different rescore builder should not be equal", mutate(firstBuilder), not(equalTo(firstBuilder))); + + RescoreBuilder secondBuilder = serializedCopy(firstBuilder); + assertTrue("rescore builder is not equal to self", secondBuilder.equals(secondBuilder)); + assertTrue("rescore builder is not equal to its copy", firstBuilder.equals(secondBuilder)); + assertTrue("equals is not symmetric", secondBuilder.equals(firstBuilder)); + assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(firstBuilder.hashCode())); + + RescoreBuilder thirdBuilder = serializedCopy(secondBuilder); + assertTrue("rescore builder is not equal to self", thirdBuilder.equals(thirdBuilder)); + assertTrue("rescore builder is not equal to its copy", secondBuilder.equals(thirdBuilder)); + assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(thirdBuilder.hashCode())); + assertTrue("equals is not transitive", firstBuilder.equals(thirdBuilder)); + assertThat("rescore builder copy's hashcode is different from original hashcode", firstBuilder.hashCode(), equalTo(thirdBuilder.hashCode())); + assertTrue("equals is not symmetric", thirdBuilder.equals(secondBuilder)); + assertTrue("equals is not symmetric", thirdBuilder.equals(firstBuilder)); + } + } + + private RescoreBuilder mutate(RescoreBuilder original) throws IOException { + RescoreBuilder mutation = serializedCopy(original); + if (randomBoolean()) { + Integer windowSize = original.windowSize(); + if (windowSize != null) { + mutation.windowSize(windowSize + 1); + } else { + mutation.windowSize(randomIntBetween(0, 100)); + } + } else { + QueryRescorer queryRescorer = (QueryRescorer) mutation.rescorer(); + switch (randomIntBetween(0, 3)) { + case 0: + queryRescorer.setQueryWeight(queryRescorer.getQueryWeight() + 0.1f); + break; + case 1: + queryRescorer.setRescoreQueryWeight(queryRescorer.getRescoreQueryWeight() + 0.1f); + break; + case 2: + QueryRescoreMode other; + do { + other = randomFrom(QueryRescoreMode.values()); + } while (other == queryRescorer.getScoreMode()); + queryRescorer.setScoreMode(other); + break; + case 3: + // only increase the boost to make it a slightly different query + queryRescorer.getRescoreQuery().boost(queryRescorer.getRescoreQuery().boost() + 0.1f); + break; + default: + throw new IllegalStateException("unexpected random mutation in test"); + } + } + return mutation; + } + + /** + * create random shape that is put under test + */ + private static RescoreBuilder randomRescoreBuilder() { + QueryBuilder queryBuilder = new MatchAllQueryBuilder().boost(randomFloat()).queryName(randomAsciiOfLength(20)); + org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer rescorer = new + org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer(queryBuilder); + if (randomBoolean()) { + rescorer.setQueryWeight(randomFloat()); + } + if (randomBoolean()) { + rescorer.setRescoreQueryWeight(randomFloat()); + } + if (randomBoolean()) { + rescorer.setScoreMode(randomFrom(QueryRescoreMode.values())); + } + RescoreBuilder builder = new RescoreBuilder(rescorer); + if (randomBoolean()) { + builder.windowSize(randomIntBetween(0, 100)); + } + return builder; + } + + private static RescoreBuilder serializedCopy(RescoreBuilder original) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + original.writeTo(output); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { + return RescoreBuilder.PROTOYPE.readFrom(in); + } + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java new file mode 100644 index 00000000000..7b4cafe716a --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.rescore; + +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +/** + * Test fixing the ordinals and names in {@link QueryRescoreMode}. These should not be changed since we + * use the names in the parser and the ordinals in serialization. + */ +public class QueryRescoreModeTests extends ESTestCase { + + /** + * Test @link {@link QueryRescoreMode} enum ordinals and names, since serilaization relies on it + */ + public void testQueryRescoreMode() throws IOException { + float primary = randomFloat(); + float secondary = randomFloat(); + assertEquals(0, QueryRescoreMode.Avg.ordinal()); + assertEquals("avg", QueryRescoreMode.Avg.toString()); + assertEquals((primary + secondary)/2.0f, QueryRescoreMode.Avg.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(1, QueryRescoreMode.Max.ordinal()); + assertEquals("max", QueryRescoreMode.Max.toString()); + assertEquals(Math.max(primary, secondary), QueryRescoreMode.Max.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(2, QueryRescoreMode.Min.ordinal()); + assertEquals("min", QueryRescoreMode.Min.toString()); + assertEquals(Math.min(primary, secondary), QueryRescoreMode.Min.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(3, QueryRescoreMode.Total.ordinal()); + assertEquals("sum", QueryRescoreMode.Total.toString()); + assertEquals(primary + secondary, QueryRescoreMode.Total.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(4, QueryRescoreMode.Multiply.ordinal()); + assertEquals("product", QueryRescoreMode.Multiply.toString()); + assertEquals(primary * secondary, QueryRescoreMode.Multiply.combine(primary, secondary), Float.MIN_VALUE); + } +}