Adding builder method to SmoothingModel implementations

Adds a method that emits a WordScorerFactory to all of the
three SmoothingModel implementatins that will be needed when
we switch to parsing the PhraseSuggestion on the coordinating
node and need to delay creating the WordScorer on the shards.
This commit is contained in:
Christoph Büscher 2016-01-27 18:56:19 +01:00
parent 513f4e6c57
commit aefdee17fd
8 changed files with 185 additions and 62 deletions

View File

@ -27,14 +27,14 @@ import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidat
import java.io.IOException;
//TODO public for tests
public final class LaplaceScorer extends WordScorer {
public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() {
@Override
public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5);
}
};
private double alpha;
public LaplaceScorer(IndexReader reader, Terms terms, String field,
@ -42,7 +42,11 @@ public final class LaplaceScorer extends WordScorer {
super(reader, terms, field, realWordLikelyhood, separator);
this.alpha = alpha;
}
double alpha() {
return this.alpha;
}
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);

View File

@ -41,7 +41,19 @@ public final class LinearInterpoatingScorer extends WordScorer {
this.bigramLambda = bigramLambda / sum;
this.trigramLambda = trigramLambda / sum;
}
double trigramLambda() {
return this.trigramLambda;
}
double bigramLambda() {
return this.bigramLambda;
}
double unigramLambda() {
return this.unigramLambda;
}
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);

View File

@ -18,8 +18,12 @@
*/
package org.elasticsearch.search.suggest.phrase;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Terms;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -30,6 +34,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.script.Template;
import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder;
import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory;
import java.io.IOException;
import java.util.ArrayList;
@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
private Float confidence;
private final Map<String, List<CandidateGenerator>> generators = new HashMap<>();
private Integer gramSize;
private SmoothingModel<?> model;
private SmoothingModel model;
private Boolean forceUnigrams;
private Integer tokenLimit;
private String preTag;
@ -159,7 +164,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
* Sets an explicit smoothing model used for this suggester. The default is
* {@link PhraseSuggestionBuilder.StupidBackoff}.
*/
public PhraseSuggestionBuilder smoothingModel(SmoothingModel<?> model) {
public PhraseSuggestionBuilder smoothingModel(SmoothingModel model) {
this.model = model;
return this;
}
@ -292,7 +297,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
* Smoothing</a> for details.
* </p>
*/
public static final class StupidBackoff extends SmoothingModel<StupidBackoff> {
public static final class StupidBackoff extends SmoothingModel {
/**
* Default discount parameter for {@link StupidBackoff} smoothing
*/
@ -341,8 +346,9 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
protected boolean doEquals(StupidBackoff other) {
return Objects.equals(discount, other.discount);
protected boolean doEquals(SmoothingModel other) {
StupidBackoff otherModel = (StupidBackoff) other;
return Objects.equals(discount, otherModel.discount);
}
@Override
@ -351,7 +357,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOException {
public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException {
XContentParser parser = parseContext.parser();
XContentParser.Token token;
String fieldName = null;
@ -366,6 +372,12 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
return new StupidBackoff(discount);
}
@Override
public WordScorerFactory buildWordScorerFactory() {
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator)
-> new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount);
}
}
/**
@ -377,7 +389,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
* Smoothing</a> for details.
* </p>
*/
public static final class Laplace extends SmoothingModel<Laplace> {
public static final class Laplace extends SmoothingModel {
private double alpha = DEFAULT_LAPLACE_ALPHA;
private static final String NAME = "laplace";
private static final ParseField ALPHA_FIELD = new ParseField("alpha");
@ -419,13 +431,14 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
public Laplace readFrom(StreamInput in) throws IOException {
public SmoothingModel readFrom(StreamInput in) throws IOException {
return new Laplace(in.readDouble());
}
@Override
protected boolean doEquals(Laplace other) {
return Objects.equals(alpha, other.alpha);
protected boolean doEquals(SmoothingModel other) {
Laplace otherModel = (Laplace) other;
return Objects.equals(alpha, otherModel.alpha);
}
@Override
@ -434,7 +447,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
public Laplace fromXContent(QueryParseContext parseContext) throws IOException {
public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException {
XContentParser parser = parseContext.parser();
XContentParser.Token token;
String fieldName = null;
@ -449,10 +462,16 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
return new Laplace(alpha);
}
@Override
public WordScorerFactory buildWordScorerFactory() {
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator)
-> new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha);
}
}
public static abstract class SmoothingModel<SM extends SmoothingModel<?>> implements NamedWriteable<SM>, ToXContent {
public static abstract class SmoothingModel implements NamedWriteable<SmoothingModel>, ToXContent {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
@ -471,16 +490,18 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
return false;
}
@SuppressWarnings("unchecked")
SM other = (SM) obj;
SmoothingModel other = (SmoothingModel) obj;
return doEquals(other);
}
public abstract SM fromXContent(QueryParseContext parseContext) throws IOException;
public abstract SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException;
public abstract WordScorerFactory buildWordScorerFactory();
/**
* subtype specific implementation of "equals".
*/
protected abstract boolean doEquals(SM other);
protected abstract boolean doEquals(SmoothingModel other);
protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
}
@ -493,7 +514,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
* Smoothing</a> for details.
* </p>
*/
public static final class LinearInterpolation extends SmoothingModel<LinearInterpolation> {
public static final class LinearInterpolation extends SmoothingModel {
private static final String NAME = "linear";
static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1);
private final double trigramLambda;
@ -563,10 +584,11 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
protected boolean doEquals(LinearInterpolation other) {
return Objects.equals(trigramLambda, other.trigramLambda) &&
Objects.equals(bigramLambda, other.bigramLambda) &&
Objects.equals(unigramLambda, other.unigramLambda);
protected boolean doEquals(SmoothingModel other) {
final LinearInterpolation otherModel = (LinearInterpolation) other;
return Objects.equals(trigramLambda, otherModel.trigramLambda) &&
Objects.equals(bigramLambda, otherModel.bigramLambda) &&
Objects.equals(unigramLambda, otherModel.unigramLambda);
}
@Override
@ -579,35 +601,45 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
XContentParser parser = parseContext.parser();
XContentParser.Token token;
String fieldName = null;
final double[] lambdas = new double[3];
double trigramLambda = 0.0;
double bigramLambda = 0.0;
double unigramLambda = 0.0;
ParseFieldMatcher matcher = parseContext.parseFieldMatcher();
while ((token = parser.nextToken()) != Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
}
if (token.isValue()) {
} else if (token.isValue()) {
if (matcher.match(fieldName, TRIGRAM_FIELD)) {
lambdas[0] = parser.doubleValue();
if (lambdas[0] < 0) {
trigramLambda = parser.doubleValue();
if (trigramLambda < 0) {
throw new IllegalArgumentException("trigram_lambda must be positive");
}
} else if (matcher.match(fieldName, BIGRAM_FIELD)) {
lambdas[1] = parser.doubleValue();
if (lambdas[1] < 0) {
bigramLambda = parser.doubleValue();
if (bigramLambda < 0) {
throw new IllegalArgumentException("bigram_lambda must be positive");
}
} else if (matcher.match(fieldName, UNIGRAM_FIELD)) {
lambdas[2] = parser.doubleValue();
if (lambdas[2] < 0) {
unigramLambda = parser.doubleValue();
if (unigramLambda < 0) {
throw new IllegalArgumentException("unigram_lambda must be positive");
}
} else {
throw new IllegalArgumentException(
"suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]");
}
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "] after [" + fieldName + "]");
}
}
return new LinearInterpolation(lambdas[0], lambdas[1], lambdas[2]);
return new LinearInterpolation(trigramLambda, bigramLambda, unigramLambda);
}
@Override
public WordScorerFactory buildWordScorerFactory() {
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) ->
new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda,
unigramLambda);
}
}

View File

@ -42,6 +42,10 @@ public class StupidBackoffScorer extends WordScorer {
this.discount = discount;
}
double discount() {
return this.discount;
}
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);

View File

@ -20,11 +20,14 @@
package org.elasticsearch.search.suggest.phrase;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
public class LaplaceModelTests extends SmoothingModelTest<Laplace> {
import static org.hamcrest.Matchers.instanceOf;
public class LaplaceModelTests extends SmoothingModelTestCase {
@Override
protected Laplace createTestModel() {
protected SmoothingModel createTestModel() {
return new Laplace(randomDoubleBetween(0.0, 10.0, false));
}
@ -32,7 +35,15 @@ public class LaplaceModelTests extends SmoothingModelTest<Laplace> {
* mutate the given model so the returned smoothing model is different
*/
@Override
protected Laplace createMutation(Laplace original) {
protected Laplace createMutation(SmoothingModel input) {
Laplace original = (Laplace) input;
return new Laplace(original.getAlpha() + 0.1);
}
@Override
void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
Laplace model = (Laplace) input;
assertThat(wordScorer, instanceOf(LaplaceScorer.class));
assertEquals(model.getAlpha(), ((LaplaceScorer) wordScorer).alpha(), Double.MIN_VALUE);
}
}

View File

@ -20,15 +20,18 @@
package org.elasticsearch.search.suggest.phrase;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInterpolation> {
import static org.hamcrest.Matchers.instanceOf;
public class LinearInterpolationModelTests extends SmoothingModelTestCase {
@Override
protected LinearInterpolation createTestModel() {
protected SmoothingModel createTestModel() {
double trigramLambda = randomDoubleBetween(0.0, 10.0, false);
double bigramLambda = randomDoubleBetween(0.0, 10.0, false);
double unigramLambda = randomDoubleBetween(0.0, 10.0, false);
// normalize
// normalize so parameters sum to 1
double sum = trigramLambda + bigramLambda + unigramLambda;
return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum);
}
@ -37,7 +40,8 @@ public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInte
* mutate the given model so the returned smoothing model is different
*/
@Override
protected LinearInterpolation createMutation(LinearInterpolation original) {
protected LinearInterpolation createMutation(SmoothingModel input) {
LinearInterpolation original = (LinearInterpolation) input;
// swap two values permute original lambda values
switch (randomIntBetween(0, 2)) {
case 0:
@ -52,4 +56,14 @@ public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInte
return new LinearInterpolation(original.getUnigramLambda(), original.getBigramLambda(), original.getTrigramLambda());
}
}
@Override
void assertWordScorer(WordScorer wordScorer, SmoothingModel in) {
LinearInterpolation testModel = (LinearInterpolation) in;
LinearInterpoatingScorer testScorer = (LinearInterpoatingScorer) wordScorer;
assertThat(wordScorer, instanceOf(LinearInterpoatingScorer.class));
assertEquals(testModel.getTrigramLambda(), (testScorer).trigramLambda(), 1e-15);
assertEquals(testModel.getBigramLambda(), (testScorer).bigramLambda(), 1e-15);
assertEquals(testModel.getUnigramLambda(), (testScorer).unigramLambda(), 1e-15);
}
}

View File

@ -19,11 +19,23 @@
package org.elasticsearch.search.suggest.phrase;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.miscellaneous.PerFieldAnalyzerWrapper;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.store.RAMDirectory;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -43,14 +55,15 @@ import org.junit.BeforeClass;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends ESTestCase {
public abstract class SmoothingModelTestCase extends ESTestCase {
private static NamedWriteableRegistry namedWriteableRegistry;
private static IndicesQueriesRegistry indicesQueriesRegistry;
/**
* setup for the whole base test class
@ -63,33 +76,31 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE);
namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE);
}
indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry);
}
@AfterClass
public static void afterClass() throws Exception {
namedWriteableRegistry = null;
indicesQueriesRegistry = null;
}
/**
* create random model that is put under test
*/
protected abstract SM createTestModel();
protected abstract SmoothingModel createTestModel();
/**
* mutate the given model so the returned smoothing model is different
*/
protected abstract SM createMutation(SM original) throws IOException;
protected abstract SmoothingModel createMutation(SmoothingModel original) throws IOException;
/**
* Test that creates new smoothing model from a random test smoothing model and checks both for equality
*/
public void testFromXContent() throws IOException {
QueryParseContext context = new QueryParseContext(indicesQueriesRegistry);
QueryParseContext context = new QueryParseContext(new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptyMap()));
context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY));
SM testModel = createTestModel();
SmoothingModel testModel = createTestModel();
XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
if (randomBoolean()) {
contentBuilder.prettyPrint();
@ -99,21 +110,45 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
contentBuilder.endObject();
XContentParser parser = XContentHelper.createParser(contentBuilder.bytes());
context.reset(parser);
SmoothingModel<?> prototype = (SmoothingModel<?>) namedWriteableRegistry.getPrototype(SmoothingModel.class,
parser.nextToken(); // go to start token, real parsing would do that in the outer element parser
SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class,
testModel.getWriteableName());
SmoothingModel<?> parsedModel = prototype.fromXContent(context);
SmoothingModel parsedModel = prototype.fromXContent(context);
assertNotSame(testModel, parsedModel);
assertEquals(testModel, parsedModel);
assertEquals(testModel.hashCode(), parsedModel.hashCode());
}
/**
* Test the WordScorer emitted by the smoothing model
*/
public void testBuildWordScorer() throws IOException {
SmoothingModel testModel = createTestModel();
Map<String, Analyzer> mapping = new HashMap<>();
mapping.put("field", new WhitespaceAnalyzer());
PerFieldAnalyzerWrapper wrapper = new PerFieldAnalyzerWrapper(new WhitespaceAnalyzer(), mapping);
IndexWriter writer = new IndexWriter(new RAMDirectory(), new IndexWriterConfig(wrapper));
Document doc = new Document();
doc.add(new Field("field", "someText", TextField.TYPE_NOT_STORED));
writer.addDocument(doc);
DirectoryReader ir = DirectoryReader.open(writer, false);
WordScorer wordScorer = testModel.buildWordScorerFactory().newScorer(ir, MultiFields.getTerms(ir , "field"), "field", 0.9d, BytesRefs.toBytesRef(" "));
assertWordScorer(wordScorer, testModel);
}
/**
* implementation dependant assertions on the wordScorer produced by the smoothing model under test
*/
abstract void assertWordScorer(WordScorer wordScorer, SmoothingModel testModel);
/**
* Test serialization and deserialization of the tested model.
*/
@SuppressWarnings("unchecked")
public void testSerialization() throws IOException {
SM testModel = createTestModel();
SM deserializedModel = (SM) copyModel(testModel);
SmoothingModel testModel = createTestModel();
SmoothingModel deserializedModel = copyModel(testModel);
assertEquals(testModel, deserializedModel);
assertEquals(testModel.hashCode(), deserializedModel.hashCode());
assertNotSame(testModel, deserializedModel);
@ -124,7 +159,7 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
*/
@SuppressWarnings("unchecked")
public void testEqualsAndHashcode() throws IOException {
SM firstModel = createTestModel();
SmoothingModel firstModel = createTestModel();
assertFalse("smoothing model is equal to null", firstModel.equals(null));
assertFalse("smoothing model is equal to incompatible type", firstModel.equals(""));
assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel));
@ -132,13 +167,13 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
equalTo(firstModel.hashCode()));
assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel)));
SM secondModel = (SM) copyModel(firstModel);
SmoothingModel secondModel = copyModel(firstModel);
assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel));
assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel));
assertTrue("equals is not symmetric", secondModel.equals(firstModel));
assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode()));
SM thirdModel = (SM) copyModel(secondModel);
SmoothingModel thirdModel = copyModel(secondModel);
assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel));
assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel));
assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode()));
@ -148,11 +183,11 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
assertTrue("equals is not symmetric", thirdModel.equals(firstModel));
}
static SmoothingModel<?> copyModel(SmoothingModel<?> original) throws IOException {
static SmoothingModel copyModel(SmoothingModel original) throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
original.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) {
SmoothingModel<?> prototype = (SmoothingModel<?>) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName());
SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName());
return prototype.readFrom(in);
}
}

View File

@ -19,12 +19,15 @@
package org.elasticsearch.search.suggest.phrase;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff;
public class StupidBackoffModelTests extends SmoothingModelTest<StupidBackoff> {
import static org.hamcrest.Matchers.instanceOf;
public class StupidBackoffModelTests extends SmoothingModelTestCase {
@Override
protected StupidBackoff createTestModel() {
protected SmoothingModel createTestModel() {
return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false));
}
@ -32,7 +35,15 @@ public class StupidBackoffModelTests extends SmoothingModelTest<StupidBackoff> {
* mutate the given model so the returned smoothing model is different
*/
@Override
protected StupidBackoff createMutation(StupidBackoff original) {
protected StupidBackoff createMutation(SmoothingModel input) {
StupidBackoff original = (StupidBackoff) input;
return new StupidBackoff(original.getDiscount() + 0.1);
}
@Override
void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
assertThat(wordScorer, instanceOf(StupidBackoffScorer.class));
StupidBackoff testModel = (StupidBackoff) input;
assertEquals(testModel.getDiscount(), ((StupidBackoffScorer) wordScorer).discount(), Double.MIN_VALUE);
}
}