From 513f4e6c57ea15114010c69ffa1665fbce13b881 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Mon, 18 Jan 2016 15:31:16 +0100 Subject: [PATCH] Add serialization and fromXContent to SmoothingModels PhraseSuggestionBuilder uses three smoothing models internally. In order to enable proper serialization / parsing from xContent to the phrase suggester later, this change starts by making the smoothing models writable, adding hashCode/equals and fromXContent. --- .../suggest/phrase/PhraseSuggestParser.java | 6 +- .../phrase/PhraseSuggestionBuilder.java | 266 ++++++++++++++++-- .../AbstractShapeBuilderTestCase.java | 1 - .../suggest/phrase/LaplaceModelTests.java | 38 +++ .../phrase/LinearInterpolationModelTests.java | 55 ++++ .../suggest/phrase/SmoothingModelTest.java | 161 +++++++++++ .../phrase/StupidBackoffModelTests.java | 38 +++ 7 files changed, 539 insertions(+), 26 deletions(-) create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java index 0b904a95720..c226d061047 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java @@ -36,6 +36,8 @@ import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestContextParser; import org.elasticsearch.search.suggest.SuggestUtils; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionContext.DirectCandidateGenerator; import java.io.IOException; @@ -265,7 +267,7 @@ public final class PhraseSuggestParser implements SuggestContextParser { }); } else if ("laplace".equals(fieldName)) { ensureNoSmoothing(suggestion); - double theAlpha = 0.5; + double theAlpha = Laplace.DEFAULT_LAPLACE_ALPHA; while ((token = parser.nextToken()) != Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -286,7 +288,7 @@ public final class PhraseSuggestParser implements SuggestContextParser { } else if ("stupid_backoff".equals(fieldName) || "stupidBackoff".equals(fieldName)) { ensureNoSmoothing(suggestion); - double theDiscount = 0.4; + double theDiscount = StupidBackoff.DEFAULT_BACKOFF_DISCOUNT; while ((token = parser.nextToken()) != Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { fieldName = parser.currentName(); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java index 1055fbe83fc..97ca09d25a1 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java @@ -18,8 +18,16 @@ */ package org.elasticsearch.search.suggest.phrase; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParser.Token; +import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder; @@ -29,6 +37,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Set; /** @@ -41,7 +50,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder> generators = new HashMap<>(); private Integer gramSize; - private SmoothingModel model; + private SmoothingModel model; private Boolean forceUnigrams; private Integer tokenLimit; private String preTag; @@ -150,7 +159,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder model) { this.model = model; return this; } @@ -283,8 +292,15 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class StupidBackoff extends SmoothingModel { - private final double discount; + public static final class StupidBackoff extends SmoothingModel { + /** + * Default discount parameter for {@link StupidBackoff} smoothing + */ + public static final double DEFAULT_BACKOFF_DISCOUNT = 0.4; + private double discount = DEFAULT_BACKOFF_DISCOUNT; + static final StupidBackoff PROTOTYPE = new StupidBackoff(DEFAULT_BACKOFF_DISCOUNT); + private static final String NAME = "stupid_backoff"; + private static final ParseField DISCOUNT_FIELD = new ParseField("discount"); /** * Creates a Stupid-Backoff smoothing model. @@ -293,15 +309,63 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class Laplace extends SmoothingModel { - private final double alpha; + public static final class Laplace extends SmoothingModel { + private double alpha = DEFAULT_LAPLACE_ALPHA; + private static final String NAME = "laplace"; + private static final ParseField ALPHA_FIELD = new ParseField("alpha"); + /** + * Default alpha parameter for laplace smoothing + */ + public static final double DEFAULT_LAPLACE_ALPHA = 0.5; + static final Laplace PROTOTYPE = new Laplace(DEFAULT_LAPLACE_ALPHA); + /** * Creates a Laplace smoothing model. * */ public Laplace(double alpha) { - super("laplace"); this.alpha = alpha; } + /** + * @return the laplace model alpha parameter + */ + public double getAlpha() { + return this.alpha; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("alpha", alpha); + builder.field(ALPHA_FIELD.getPreferredName(), alpha); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(alpha); + } + + @Override + public Laplace readFrom(StreamInput in) throws IOException { + return new Laplace(in.readDouble()); + } + + @Override + protected boolean doEquals(Laplace other) { + return Objects.equals(alpha, other.alpha); + } + + @Override + public final int hashCode() { + return Objects.hash(alpha); + } + + @Override + public Laplace fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + double alpha = DEFAULT_LAPLACE_ALPHA; + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } + if (token.isValue() && parseContext.parseFieldMatcher().match(fieldName, ALPHA_FIELD)) { + alpha = parser.doubleValue(); + } + } + return new Laplace(alpha); + } } - public static abstract class SmoothingModel implements ToXContent { - private final String type; - - protected SmoothingModel(String type) { - this.type = type; - } + public static abstract class SmoothingModel> implements NamedWriteable, ToXContent { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(type); + builder.startObject(getWriteableName()); innerToXContent(builder,params); builder.endObject(); return builder; } + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + @SuppressWarnings("unchecked") + SM other = (SM) obj; + return doEquals(other); + } + + public abstract SM fromXContent(QueryParseContext parseContext) throws IOException; + + /** + * subtype specific implementation of "equals". + */ + protected abstract boolean doEquals(SM other); + protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; } @@ -358,10 +493,15 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class LinearInterpolation extends SmoothingModel { + public static final class LinearInterpolation extends SmoothingModel { + private static final String NAME = "linear"; + static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1); private final double trigramLambda; private final double bigramLambda; private final double unigramLambda; + private static final ParseField TRIGRAM_FIELD = new ParseField("trigram_lambda"); + private static final ParseField BIGRAM_FIELD = new ParseField("bigram_lambda"); + private static final ParseField UNIGRAM_FIELD = new ParseField("unigram_lambda"); /** * Creates a linear interpolation smoothing model. @@ -376,19 +516,99 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder 0.001) { + throw new IllegalArgumentException("linear smoothing lambdas must sum to 1"); + } this.trigramLambda = trigramLambda; this.bigramLambda = bigramLambda; this.unigramLambda = unigramLambda; } + public double getTrigramLambda() { + return this.trigramLambda; + } + + public double getBigramLambda() { + return this.bigramLambda; + } + + public double getUnigramLambda() { + return this.unigramLambda; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("trigram_lambda", trigramLambda); - builder.field("bigram_lambda", bigramLambda); - builder.field("unigram_lambda", unigramLambda); + builder.field(TRIGRAM_FIELD.getPreferredName(), trigramLambda); + builder.field(BIGRAM_FIELD.getPreferredName(), bigramLambda); + builder.field(UNIGRAM_FIELD.getPreferredName(), unigramLambda); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(trigramLambda); + out.writeDouble(bigramLambda); + out.writeDouble(unigramLambda); + } + + @Override + public LinearInterpolation readFrom(StreamInput in) throws IOException { + return new LinearInterpolation(in.readDouble(), in.readDouble(), in.readDouble()); + } + + @Override + protected boolean doEquals(LinearInterpolation other) { + return Objects.equals(trigramLambda, other.trigramLambda) && + Objects.equals(bigramLambda, other.bigramLambda) && + Objects.equals(unigramLambda, other.unigramLambda); + } + + @Override + public final int hashCode() { + return Objects.hash(trigramLambda, bigramLambda, unigramLambda); + } + + @Override + public LinearInterpolation fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + final double[] lambdas = new double[3]; + ParseFieldMatcher matcher = parseContext.parseFieldMatcher(); + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } + if (token.isValue()) { + if (matcher.match(fieldName, TRIGRAM_FIELD)) { + lambdas[0] = parser.doubleValue(); + if (lambdas[0] < 0) { + throw new IllegalArgumentException("trigram_lambda must be positive"); + } + } else if (matcher.match(fieldName, BIGRAM_FIELD)) { + lambdas[1] = parser.doubleValue(); + if (lambdas[1] < 0) { + throw new IllegalArgumentException("bigram_lambda must be positive"); + } + } else if (matcher.match(fieldName, UNIGRAM_FIELD)) { + lambdas[2] = parser.doubleValue(); + if (lambdas[2] < 0) { + throw new IllegalArgumentException("unigram_lambda must be positive"); + } + } else { + throw new IllegalArgumentException( + "suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]"); + } + } + } + return new LinearInterpolation(lambdas[0], lambdas[1], lambdas[2]); + } } /** @@ -428,7 +648,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder exte } XContentBuilder builder = testShape.toXContent(contentBuilder, ToXContent.EMPTY_PARAMS); XContentParser shapeParser = XContentHelper.createParser(builder.bytes()); - XContentHelper.createParser(builder.bytes()); shapeParser.nextToken(); ShapeBuilder parsedShape = ShapeBuilder.parse(shapeParser); assertNotSame(testShape, parsedShape); diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java new file mode 100644 index 00000000000..e2256e98f6a --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; + +public class LaplaceModelTests extends SmoothingModelTest { + + @Override + protected Laplace createTestModel() { + return new Laplace(randomDoubleBetween(0.0, 10.0, false)); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected Laplace createMutation(Laplace original) { + return new Laplace(original.getAlpha() + 0.1); + } +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java new file mode 100644 index 00000000000..467bca7f0ab --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java @@ -0,0 +1,55 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; + +public class LinearInterpolationModelTests extends SmoothingModelTest { + + @Override + protected LinearInterpolation createTestModel() { + double trigramLambda = randomDoubleBetween(0.0, 10.0, false); + double bigramLambda = randomDoubleBetween(0.0, 10.0, false); + double unigramLambda = randomDoubleBetween(0.0, 10.0, false); + // normalize + double sum = trigramLambda + bigramLambda + unigramLambda; + return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected LinearInterpolation createMutation(LinearInterpolation original) { + // swap two values permute original lambda values + switch (randomIntBetween(0, 2)) { + case 0: + // swap first two + return new LinearInterpolation(original.getBigramLambda(), original.getTrigramLambda(), original.getUnigramLambda()); + case 1: + // swap last two + return new LinearInterpolation(original.getTrigramLambda(), original.getUnigramLambda(), original.getBigramLambda()); + case 2: + default: + // swap first and last + return new LinearInterpolation(original.getUnigramLambda(), original.getBigramLambda(), original.getTrigramLambda()); + } + } +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java new file mode 100644 index 00000000000..b2dbe17e67d --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryParseContext; +import org.elasticsearch.indices.query.IndicesQueriesRegistry; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public abstract class SmoothingModelTest> extends ESTestCase { + + private static NamedWriteableRegistry namedWriteableRegistry; + private static IndicesQueriesRegistry indicesQueriesRegistry; + + /** + * setup for the whole base test class + */ + @BeforeClass + public static void init() { + if (namedWriteableRegistry == null) { + namedWriteableRegistry = new NamedWriteableRegistry(); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, Laplace.PROTOTYPE); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE); + } + indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry); + } + + @AfterClass + public static void afterClass() throws Exception { + namedWriteableRegistry = null; + indicesQueriesRegistry = null; + } + + /** + * create random model that is put under test + */ + protected abstract SM createTestModel(); + + /** + * mutate the given model so the returned smoothing model is different + */ + protected abstract SM createMutation(SM original) throws IOException; + + /** + * Test that creates new smoothing model from a random test smoothing model and checks both for equality + */ + public void testFromXContent() throws IOException { + QueryParseContext context = new QueryParseContext(indicesQueriesRegistry); + context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY)); + + SM testModel = createTestModel(); + XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); + if (randomBoolean()) { + contentBuilder.prettyPrint(); + } + contentBuilder.startObject(); + testModel.innerToXContent(contentBuilder, ToXContent.EMPTY_PARAMS); + contentBuilder.endObject(); + XContentParser parser = XContentHelper.createParser(contentBuilder.bytes()); + context.reset(parser); + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, + testModel.getWriteableName()); + SmoothingModel parsedModel = prototype.fromXContent(context); + assertNotSame(testModel, parsedModel); + assertEquals(testModel, parsedModel); + assertEquals(testModel.hashCode(), parsedModel.hashCode()); + } + + /** + * Test serialization and deserialization of the tested model. + */ + @SuppressWarnings("unchecked") + public void testSerialization() throws IOException { + SM testModel = createTestModel(); + SM deserializedModel = (SM) copyModel(testModel); + assertEquals(testModel, deserializedModel); + assertEquals(testModel.hashCode(), deserializedModel.hashCode()); + assertNotSame(testModel, deserializedModel); + } + + /** + * Test equality and hashCode properties + */ + @SuppressWarnings("unchecked") + public void testEqualsAndHashcode() throws IOException { + SM firstModel = createTestModel(); + assertFalse("smoothing model is equal to null", firstModel.equals(null)); + assertFalse("smoothing model is equal to incompatible type", firstModel.equals("")); + assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel)); + assertThat("same smoothing model's hashcode returns different values if called multiple times", firstModel.hashCode(), + equalTo(firstModel.hashCode())); + assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel))); + + SM secondModel = (SM) copyModel(firstModel); + assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel)); + assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel)); + assertTrue("equals is not symmetric", secondModel.equals(firstModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode())); + + SM thirdModel = (SM) copyModel(secondModel); + assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel)); + assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode())); + assertTrue("equals is not transitive", firstModel.equals(thirdModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", firstModel.hashCode(), equalTo(thirdModel.hashCode())); + assertTrue("equals is not symmetric", thirdModel.equals(secondModel)); + assertTrue("equals is not symmetric", thirdModel.equals(firstModel)); + } + + static SmoothingModel copyModel(SmoothingModel original) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + original.writeTo(output); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); + return prototype.readFrom(in); + } + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java new file mode 100644 index 00000000000..5d774066e07 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; + +public class StupidBackoffModelTests extends SmoothingModelTest { + + @Override + protected StupidBackoff createTestModel() { + return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false)); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected StupidBackoff createMutation(StupidBackoff original) { + return new StupidBackoff(original.getDiscount() + 0.1); + } +}