diff --git a/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java b/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java index 52d45ec9407..1557c266bd4 100644 --- a/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java +++ b/core/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java @@ -421,7 +421,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder { + Avg { + @Override + public float combine(float primary, float secondary) { + return (primary + secondary) / 2; + } + + @Override + public String toString() { + return "avg"; + } + }, + Max { + @Override + public float combine(float primary, float secondary) { + return Math.max(primary, secondary); + } + + @Override + public String toString() { + return "max"; + } + }, + Min { + @Override + public float combine(float primary, float secondary) { + return Math.min(primary, secondary); + } + + @Override + public String toString() { + return "min"; + } + }, + Total { + @Override + public float combine(float primary, float secondary) { + return primary + secondary; + } + + @Override + public String toString() { + return "sum"; + } + }, + Multiply { + @Override + public float combine(float primary, float secondary) { + return primary * secondary; + } + + @Override + public String toString() { + return "product"; + } + }; + + public abstract float combine(float primary, float secondary); + + static QueryRescoreMode PROTOTYPE = Total; + + @Override + public QueryRescoreMode readFrom(StreamInput in) throws IOException { + int ordinal = in.readVInt(); + if (ordinal < 0 || ordinal >= values().length) { + throw new IOException("Unknown ScoreMode ordinal [" + ordinal + "]"); + } + return values()[ordinal]; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(this.ordinal()); + } + + public static QueryRescoreMode fromString(String scoreMode) { + for (QueryRescoreMode mode : values()) { + if (scoreMode.toLowerCase(Locale.ROOT).equals(mode.name().toLowerCase(Locale.ROOT))) { + return mode; + } + } + throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]"); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} \ No newline at end of file diff --git a/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java b/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java index afbd034840e..7f95ff10824 100644 --- a/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java +++ b/core/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java @@ -38,66 +38,6 @@ import java.util.Set; public final class QueryRescorer implements Rescorer { - private static enum ScoreMode { - Avg { - @Override - public float combine(float primary, float secondary) { - return (primary + secondary) / 2; - } - - @Override - public String toString() { - return "avg"; - } - }, - Max { - @Override - public float combine(float primary, float secondary) { - return Math.max(primary, secondary); - } - - @Override - public String toString() { - return "max"; - } - }, - Min { - @Override - public float combine(float primary, float secondary) { - return Math.min(primary, secondary); - } - - @Override - public String toString() { - return "min"; - } - }, - Total { - @Override - public float combine(float primary, float secondary) { - return primary + secondary; - } - - @Override - public String toString() { - return "sum"; - } - }, - Multiply { - @Override - public float combine(float primary, float secondary) { - return primary * secondary; - } - - @Override - public String toString() { - return "product"; - } - }; - - public abstract float combine(float primary, float secondary); - } - public static final Rescorer INSTANCE = new QueryRescorer(); public static final String NAME = "query"; @@ -170,7 +110,7 @@ public final class QueryRescorer implements Rescorer { rescoreExplain.getValue() * secondaryWeight, "product of:", rescoreExplain, Explanation.match(secondaryWeight, "secondaryWeight")); - ScoreMode scoreMode = rescore.scoreMode(); + QueryRescoreMode scoreMode = rescore.scoreMode(); return Explanation.match( scoreMode.combine(prim.getValue(), sec.getValue()), scoreMode + " of:", @@ -228,7 +168,7 @@ public final class QueryRescorer implements Rescorer { // secondary score? in.scoreDocs[i].score *= ctx.queryWeight(); } - + // TODO: this is wrong, i.e. we are comparing apples and oranges at this point. It would be better if we always rescored all // incoming first pass hits, instead of allowing recoring of just the top subset: Arrays.sort(in.scoreDocs, SCORE_DOC_COMPARATOR); @@ -240,13 +180,13 @@ public final class QueryRescorer implements Rescorer { public QueryRescoreContext(QueryRescorer rescorer) { super(NAME, 10, rescorer); - this.scoreMode = ScoreMode.Total; + this.scoreMode = QueryRescoreMode.Total; } private ParsedQuery parsedQuery; private float queryWeight = 1.0f; private float rescoreQueryWeight = 1.0f; - private ScoreMode scoreMode; + private QueryRescoreMode scoreMode; public void setParsedQuery(ParsedQuery parsedQuery) { this.parsedQuery = parsedQuery; @@ -264,7 +204,7 @@ public final class QueryRescorer implements Rescorer { return rescoreQueryWeight; } - public ScoreMode scoreMode() { + public QueryRescoreMode scoreMode() { return scoreMode; } @@ -276,26 +216,13 @@ public final class QueryRescorer implements Rescorer { this.queryWeight = queryWeight; } - public void setScoreMode(ScoreMode scoreMode) { + public void setScoreMode(QueryRescoreMode scoreMode) { this.scoreMode = scoreMode; } public void setScoreMode(String scoreMode) { - if ("avg".equals(scoreMode)) { - setScoreMode(ScoreMode.Avg); - } else if ("max".equals(scoreMode)) { - setScoreMode(ScoreMode.Max); - } else if ("min".equals(scoreMode)) { - setScoreMode(ScoreMode.Min); - } else if ("total".equals(scoreMode)) { - setScoreMode(ScoreMode.Total); - } else if ("multiply".equals(scoreMode)) { - setScoreMode(ScoreMode.Multiply); - } else { - throw new IllegalArgumentException("illegal score_mode [" + scoreMode + "]"); - } + setScoreMode(QueryRescoreMode.fromString(scoreMode)); } - } @Override diff --git a/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java b/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java index fde282427d7..7510d24f82d 100644 --- a/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/rescore/RescoreBuilder.java @@ -19,24 +19,36 @@ package org.elasticsearch.search.rescore; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; +import java.util.Locale; +import java.util.Objects; -public class RescoreBuilder implements ToXContent { +public class RescoreBuilder implements ToXContent, Writeable { private Rescorer rescorer; private Integer windowSize; + public static final RescoreBuilder PROTOYPE = new RescoreBuilder(new QueryRescorer(new MatchAllQueryBuilder())); - public static QueryRescorer queryRescorer(QueryBuilder queryBuilder) { - return new QueryRescorer(queryBuilder); + public RescoreBuilder(Rescorer rescorer) { + if (rescorer == null) { + throw new IllegalArgumentException("rescorer cannot be null"); + } + this.rescorer = rescorer; } - public RescoreBuilder rescorer(Rescorer rescorer) { - this.rescorer = rescorer; - return this; + public Rescorer rescorer() { + return this.rescorer; } public RescoreBuilder windowSize(int windowSize) { @@ -48,10 +60,6 @@ public class RescoreBuilder implements ToXContent { return windowSize; } - public boolean isEmpty() { - return rescorer == null; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { if (windowSize != null) { @@ -61,13 +69,66 @@ public class RescoreBuilder implements ToXContent { return builder; } - public static abstract class Rescorer implements ToXContent { + public static QueryRescorer queryRescorer(QueryBuilder queryBuilder) { + return new QueryRescorer(queryBuilder); + } + + @Override + public final int hashCode() { + return Objects.hash(windowSize, rescorer); + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + RescoreBuilder other = (RescoreBuilder) obj; + return Objects.equals(windowSize, other.windowSize) && + Objects.equals(rescorer, other.rescorer); + } + + @Override + public RescoreBuilder readFrom(StreamInput in) throws IOException { + RescoreBuilder builder = new RescoreBuilder(in.readRescorer()); + Integer windowSize = in.readOptionalVInt(); + if (windowSize != null) { + builder.windowSize(windowSize); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeRescorer(rescorer); + out.writeOptionalVInt(this.windowSize); + } + + @Override + public final String toString() { + try { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.prettyPrint(); + builder.startObject(); + toXContent(builder, EMPTY_PARAMS); + builder.endObject(); + return builder.string(); + } catch (Exception e) { + return "{ \"error\" : \"" + ExceptionsHelper.detailedMessage(e) + "\"}"; + } + } + + public static abstract class Rescorer implements ToXContent, NamedWriteable { private String name; public Rescorer(String name) { this.name = name; } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(name); @@ -78,23 +139,41 @@ public class RescoreBuilder implements ToXContent { protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; + @Override + public abstract int hashCode(); + + @Override + public abstract boolean equals(Object obj); } public static class QueryRescorer extends Rescorer { + private static final String NAME = "query"; - private QueryBuilder queryBuilder; - private Float rescoreQueryWeight; - private Float queryWeight; - private String scoreMode; + public static final QueryRescorer PROTOTYPE = new QueryRescorer(new MatchAllQueryBuilder()); + public static final float DEFAULT_RESCORE_QUERYWEIGHT = 1.0f; + public static final float DEFAULT_QUERYWEIGHT = 1.0f; + public static final QueryRescoreMode DEFAULT_SCORE_MODE = QueryRescoreMode.Total; + private final QueryBuilder queryBuilder; + private float rescoreQueryWeight = DEFAULT_RESCORE_QUERYWEIGHT; + private float queryWeight = DEFAULT_QUERYWEIGHT; + private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE; /** * Creates a new {@link QueryRescorer} instance * @param builder the query builder to build the rescore query from */ - public QueryRescorer(QueryBuilder builder) { + public QueryRescorer(QueryBuilder builder) { super(NAME); this.queryBuilder = builder; } + + /** + * @return the query used for this rescore query + */ + public QueryBuilder getRescoreQuery() { + return this.queryBuilder; + } + /** * Sets the original query weight for rescoring. The default is 1.0 */ @@ -103,6 +182,14 @@ public class RescoreBuilder implements ToXContent { return this; } + + /** + * Gets the original query weight for rescoring. The default is 1.0 + */ + public float getQueryWeight() { + return this.queryWeight; + } + /** * Sets the original query weight for rescoring. The default is 1.0 */ @@ -112,27 +199,76 @@ public class RescoreBuilder implements ToXContent { } /** - * Sets the original query score mode. The default is total + * Gets the original query weight for rescoring. The default is 1.0 */ - public QueryRescorer setScoreMode(String scoreMode) { + public float getRescoreQueryWeight() { + return this.rescoreQueryWeight; + } + + /** + * Sets the original query score mode. The default is {@link QueryRescoreMode#Total}. + */ + public QueryRescorer setScoreMode(QueryRescoreMode scoreMode) { this.scoreMode = scoreMode; return this; } + /** + * Gets the original query score mode. The default is total + */ + public QueryRescoreMode getScoreMode() { + return this.scoreMode; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("rescore_query", queryBuilder); - if (queryWeight != null) { - builder.field("query_weight", queryWeight); - } - if (rescoreQueryWeight != null) { - builder.field("rescore_query_weight", rescoreQueryWeight); - } - if (scoreMode != null) { - builder.field("score_mode", scoreMode); - } + builder.field("query_weight", queryWeight); + builder.field("rescore_query_weight", rescoreQueryWeight); + builder.field("score_mode", scoreMode.name().toLowerCase(Locale.ROOT)); return builder; } - } + @Override + public final int hashCode() { + return Objects.hash(getClass(), scoreMode, queryWeight, rescoreQueryWeight, queryBuilder); + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + QueryRescorer other = (QueryRescorer) obj; + return Objects.equals(scoreMode, other.scoreMode) && + Objects.equals(queryWeight, other.queryWeight) && + Objects.equals(rescoreQueryWeight, other.rescoreQueryWeight) && + Objects.equals(queryBuilder, other.queryBuilder); + } + + @Override + public QueryRescorer readFrom(StreamInput in) throws IOException { + QueryRescorer rescorer = new QueryRescorer(in.readQuery()); + rescorer.setScoreMode(QueryRescoreMode.PROTOTYPE.readFrom(in)); + rescorer.setRescoreQueryWeight(in.readFloat()); + rescorer.setQueryWeight(in.readFloat()); + return rescorer; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeQuery(queryBuilder); + scoreMode.writeTo(out); + out.writeFloat(rescoreQueryWeight); + out.writeFloat(queryWeight); + } + + @Override + public String getWriteableName() { + return NAME; + } + } } diff --git a/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index b80810fc6d5..d7ede712447 100644 --- a/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -274,8 +274,7 @@ public class SearchSourceBuilderTests extends ESTestCase { int numRescores = randomIntBetween(1, 5); for (int i = 0; i < numRescores; i++) { // NORELEASE need a random rescore builder method - RescoreBuilder rescoreBuilder = new RescoreBuilder(); - rescoreBuilder.rescorer(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20), + RescoreBuilder rescoreBuilder = new RescoreBuilder(RescoreBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20), randomAsciiOfLengthBetween(5, 20)))); builder.addRescorer(rescoreBuilder); } diff --git a/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java b/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java index 861701a9f6a..5644f893603 100644 --- a/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java +++ b/core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java @@ -37,6 +37,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.RescoreBuilder; import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer; import org.elasticsearch.test.ESIntegTestCase; @@ -541,7 +542,7 @@ public class QueryRescorerIT extends ESIntegTestCase { .setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); if (!"".equals(scoreModes[innerMode])) { - innerRescoreQuery.setScoreMode(scoreModes[innerMode]); + innerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[innerMode])); } SearchResponse searchResponse = client() @@ -564,7 +565,7 @@ public class QueryRescorerIT extends ESIntegTestCase { .boost(4.0f)).setQueryWeight(0.5f).setRescoreQueryWeight(0.4f); if (!"".equals(scoreModes[outerMode])) { - outerRescoreQuery.setScoreMode(scoreModes[outerMode]); + outerRescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreModes[outerMode])); } searchResponse = client() @@ -612,7 +613,7 @@ public class QueryRescorerIT extends ESIntegTestCase { .setRescoreQueryWeight(secondaryWeight); if (!"".equals(scoreMode)) { - rescoreQuery.setScoreMode(scoreMode); + rescoreQuery.setScoreMode(QueryRescoreMode.fromString(scoreMode)); } SearchResponse rescored = client() @@ -683,11 +684,11 @@ public class QueryRescorerIT extends ESIntegTestCase { int numDocs = indexRandomNumbers("keyword", 1, true); QueryRescorer eightIsGreat = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(8)), - ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode("total"); + ScoreFunctionBuilders.weightFactorFunction(1000.0f)).boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); QueryRescorer sevenIsBetter = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", English.intToEnglish(7)), ScoreFunctionBuilders.weightFactorFunction(10000.0f)).boostMode(CombineFunction.REPLACE)) - .setScoreMode("total"); + .setScoreMode(QueryRescoreMode.Total); // First set the rescore window large enough that both rescores take effect SearchRequestBuilder request = client().prepareSearch(); @@ -704,10 +705,10 @@ public class QueryRescorerIT extends ESIntegTestCase { // Now use one rescore to drag the number we're looking for into the window of another QueryRescorer ninetyIsGood = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*ninety*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f)) - .boostMode(CombineFunction.REPLACE)).setScoreMode("total"); + .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); QueryRescorer oneToo = RescoreBuilder.queryRescorer( QueryBuilders.functionScoreQuery(QueryBuilders.queryStringQuery("*one*"), ScoreFunctionBuilders.weightFactorFunction(1000.0f)) - .boostMode(CombineFunction.REPLACE)).setScoreMode("total"); + .boostMode(CombineFunction.REPLACE)).setScoreMode(QueryRescoreMode.Total); request.clearRescorers().addRescorer(ninetyIsGood, numDocs).addRescorer(oneToo, 10); response = request.setSize(2).get(); assertFirstHit(response, hasId("91")); diff --git a/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java new file mode 100644 index 00000000000..2aa55f8b626 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java @@ -0,0 +1,170 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.rescore; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer; +import org.elasticsearch.search.rescore.RescoreBuilder.Rescorer; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class QueryRescoreBuilderTests extends ESTestCase { + + private static final int NUMBER_OF_TESTBUILDERS = 20; + private static NamedWriteableRegistry namedWriteableRegistry; + + /** + * setup for the whole base test class + */ + @BeforeClass + public static void init() { + namedWriteableRegistry = new NamedWriteableRegistry(); + namedWriteableRegistry.registerPrototype(Rescorer.class, org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer.PROTOTYPE); + namedWriteableRegistry.registerPrototype(QueryBuilder.class, new MatchAllQueryBuilder()); + } + + @AfterClass + public static void afterClass() throws Exception { + namedWriteableRegistry = null; + } + + /** + * Test serialization and deserialization of the rescore builder + */ + public void testSerialization() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) { + RescoreBuilder original = randomRescoreBuilder(); + RescoreBuilder deserialized = serializedCopy(original); + assertEquals(deserialized, original); + assertEquals(deserialized.hashCode(), original.hashCode()); + assertNotSame(deserialized, original); + } + } + + /** + * Test equality and hashCode properties + */ + public void testEqualsAndHashcode() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TESTBUILDERS; runs++) { + RescoreBuilder firstBuilder = randomRescoreBuilder(); + assertFalse("rescore builder is equal to null", firstBuilder.equals(null)); + assertFalse("rescore builder is equal to incompatible type", firstBuilder.equals("")); + assertTrue("rescore builder is not equal to self", firstBuilder.equals(firstBuilder)); + assertThat("same rescore builder's hashcode returns different values if called multiple times", firstBuilder.hashCode(), + equalTo(firstBuilder.hashCode())); + assertThat("different rescore builder should not be equal", mutate(firstBuilder), not(equalTo(firstBuilder))); + + RescoreBuilder secondBuilder = serializedCopy(firstBuilder); + assertTrue("rescore builder is not equal to self", secondBuilder.equals(secondBuilder)); + assertTrue("rescore builder is not equal to its copy", firstBuilder.equals(secondBuilder)); + assertTrue("equals is not symmetric", secondBuilder.equals(firstBuilder)); + assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(firstBuilder.hashCode())); + + RescoreBuilder thirdBuilder = serializedCopy(secondBuilder); + assertTrue("rescore builder is not equal to self", thirdBuilder.equals(thirdBuilder)); + assertTrue("rescore builder is not equal to its copy", secondBuilder.equals(thirdBuilder)); + assertThat("rescore builder copy's hashcode is different from original hashcode", secondBuilder.hashCode(), equalTo(thirdBuilder.hashCode())); + assertTrue("equals is not transitive", firstBuilder.equals(thirdBuilder)); + assertThat("rescore builder copy's hashcode is different from original hashcode", firstBuilder.hashCode(), equalTo(thirdBuilder.hashCode())); + assertTrue("equals is not symmetric", thirdBuilder.equals(secondBuilder)); + assertTrue("equals is not symmetric", thirdBuilder.equals(firstBuilder)); + } + } + + private RescoreBuilder mutate(RescoreBuilder original) throws IOException { + RescoreBuilder mutation = serializedCopy(original); + if (randomBoolean()) { + Integer windowSize = original.windowSize(); + if (windowSize != null) { + mutation.windowSize(windowSize + 1); + } else { + mutation.windowSize(randomIntBetween(0, 100)); + } + } else { + QueryRescorer queryRescorer = (QueryRescorer) mutation.rescorer(); + switch (randomIntBetween(0, 3)) { + case 0: + queryRescorer.setQueryWeight(queryRescorer.getQueryWeight() + 0.1f); + break; + case 1: + queryRescorer.setRescoreQueryWeight(queryRescorer.getRescoreQueryWeight() + 0.1f); + break; + case 2: + QueryRescoreMode other; + do { + other = randomFrom(QueryRescoreMode.values()); + } while (other == queryRescorer.getScoreMode()); + queryRescorer.setScoreMode(other); + break; + case 3: + // only increase the boost to make it a slightly different query + queryRescorer.getRescoreQuery().boost(queryRescorer.getRescoreQuery().boost() + 0.1f); + break; + default: + throw new IllegalStateException("unexpected random mutation in test"); + } + } + return mutation; + } + + /** + * create random shape that is put under test + */ + private static RescoreBuilder randomRescoreBuilder() { + QueryBuilder queryBuilder = new MatchAllQueryBuilder().boost(randomFloat()).queryName(randomAsciiOfLength(20)); + org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer rescorer = new + org.elasticsearch.search.rescore.RescoreBuilder.QueryRescorer(queryBuilder); + if (randomBoolean()) { + rescorer.setQueryWeight(randomFloat()); + } + if (randomBoolean()) { + rescorer.setRescoreQueryWeight(randomFloat()); + } + if (randomBoolean()) { + rescorer.setScoreMode(randomFrom(QueryRescoreMode.values())); + } + RescoreBuilder builder = new RescoreBuilder(rescorer); + if (randomBoolean()) { + builder.windowSize(randomIntBetween(0, 100)); + } + return builder; + } + + private static RescoreBuilder serializedCopy(RescoreBuilder original) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + original.writeTo(output); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { + return RescoreBuilder.PROTOYPE.readFrom(in); + } + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java new file mode 100644 index 00000000000..7b4cafe716a --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreModeTests.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.rescore; + +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +/** + * Test fixing the ordinals and names in {@link QueryRescoreMode}. These should not be changed since we + * use the names in the parser and the ordinals in serialization. + */ +public class QueryRescoreModeTests extends ESTestCase { + + /** + * Test @link {@link QueryRescoreMode} enum ordinals and names, since serilaization relies on it + */ + public void testQueryRescoreMode() throws IOException { + float primary = randomFloat(); + float secondary = randomFloat(); + assertEquals(0, QueryRescoreMode.Avg.ordinal()); + assertEquals("avg", QueryRescoreMode.Avg.toString()); + assertEquals((primary + secondary)/2.0f, QueryRescoreMode.Avg.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(1, QueryRescoreMode.Max.ordinal()); + assertEquals("max", QueryRescoreMode.Max.toString()); + assertEquals(Math.max(primary, secondary), QueryRescoreMode.Max.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(2, QueryRescoreMode.Min.ordinal()); + assertEquals("min", QueryRescoreMode.Min.toString()); + assertEquals(Math.min(primary, secondary), QueryRescoreMode.Min.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(3, QueryRescoreMode.Total.ordinal()); + assertEquals("sum", QueryRescoreMode.Total.toString()); + assertEquals(primary + secondary, QueryRescoreMode.Total.combine(primary, secondary), Float.MIN_VALUE); + + assertEquals(4, QueryRescoreMode.Multiply.ordinal()); + assertEquals("product", QueryRescoreMode.Multiply.toString()); + assertEquals(primary * secondary, QueryRescoreMode.Multiply.combine(primary, secondary), Float.MIN_VALUE); + } +}