Adding builder method to SmoothingModel implementations
Adds a method that emits a WordScorerFactory to all of the three SmoothingModel implementatins that will be needed when we switch to parsing the PhraseSuggestion on the coordinating node and need to delay creating the WordScorer on the shards.
This commit is contained in:
parent
513f4e6c57
commit
aefdee17fd
|
@ -27,14 +27,14 @@ import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidat
|
|||
import java.io.IOException;
|
||||
//TODO public for tests
|
||||
public final class LaplaceScorer extends WordScorer {
|
||||
|
||||
|
||||
public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() {
|
||||
@Override
|
||||
public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
|
||||
return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
private double alpha;
|
||||
|
||||
public LaplaceScorer(IndexReader reader, Terms terms, String field,
|
||||
|
@ -42,7 +42,11 @@ public final class LaplaceScorer extends WordScorer {
|
|||
super(reader, terms, field, realWordLikelyhood, separator);
|
||||
this.alpha = alpha;
|
||||
}
|
||||
|
||||
|
||||
double alpha() {
|
||||
return this.alpha;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
|
||||
SuggestUtils.join(separator, spare, w_1.term, word.term);
|
||||
|
|
|
@ -41,7 +41,19 @@ public final class LinearInterpoatingScorer extends WordScorer {
|
|||
this.bigramLambda = bigramLambda / sum;
|
||||
this.trigramLambda = trigramLambda / sum;
|
||||
}
|
||||
|
||||
|
||||
double trigramLambda() {
|
||||
return this.trigramLambda;
|
||||
}
|
||||
|
||||
double bigramLambda() {
|
||||
return this.bigramLambda;
|
||||
}
|
||||
|
||||
double unigramLambda() {
|
||||
return this.unigramLambda;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
|
||||
SuggestUtils.join(separator, spare, w_1.term, word.term);
|
||||
|
|
|
@ -18,8 +18,12 @@
|
|||
*/
|
||||
package org.elasticsearch.search.suggest.phrase;
|
||||
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.ParseFieldMatcher;
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
@ -30,6 +34,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token;
|
|||
import org.elasticsearch.index.query.QueryParseContext;
|
||||
import org.elasticsearch.script.Template;
|
||||
import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder;
|
||||
import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
private Float confidence;
|
||||
private final Map<String, List<CandidateGenerator>> generators = new HashMap<>();
|
||||
private Integer gramSize;
|
||||
private SmoothingModel<?> model;
|
||||
private SmoothingModel model;
|
||||
private Boolean forceUnigrams;
|
||||
private Integer tokenLimit;
|
||||
private String preTag;
|
||||
|
@ -159,7 +164,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
* Sets an explicit smoothing model used for this suggester. The default is
|
||||
* {@link PhraseSuggestionBuilder.StupidBackoff}.
|
||||
*/
|
||||
public PhraseSuggestionBuilder smoothingModel(SmoothingModel<?> model) {
|
||||
public PhraseSuggestionBuilder smoothingModel(SmoothingModel model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
@ -292,7 +297,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
* Smoothing</a> for details.
|
||||
* </p>
|
||||
*/
|
||||
public static final class StupidBackoff extends SmoothingModel<StupidBackoff> {
|
||||
public static final class StupidBackoff extends SmoothingModel {
|
||||
/**
|
||||
* Default discount parameter for {@link StupidBackoff} smoothing
|
||||
*/
|
||||
|
@ -341,8 +346,9 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(StupidBackoff other) {
|
||||
return Objects.equals(discount, other.discount);
|
||||
protected boolean doEquals(SmoothingModel other) {
|
||||
StupidBackoff otherModel = (StupidBackoff) other;
|
||||
return Objects.equals(discount, otherModel.discount);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -351,7 +357,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
|
||||
@Override
|
||||
public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOException {
|
||||
public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException {
|
||||
XContentParser parser = parseContext.parser();
|
||||
XContentParser.Token token;
|
||||
String fieldName = null;
|
||||
|
@ -366,6 +372,12 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
return new StupidBackoff(discount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public WordScorerFactory buildWordScorerFactory() {
|
||||
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator)
|
||||
-> new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -377,7 +389,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
* Smoothing</a> for details.
|
||||
* </p>
|
||||
*/
|
||||
public static final class Laplace extends SmoothingModel<Laplace> {
|
||||
public static final class Laplace extends SmoothingModel {
|
||||
private double alpha = DEFAULT_LAPLACE_ALPHA;
|
||||
private static final String NAME = "laplace";
|
||||
private static final ParseField ALPHA_FIELD = new ParseField("alpha");
|
||||
|
@ -419,13 +431,14 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
|
||||
@Override
|
||||
public Laplace readFrom(StreamInput in) throws IOException {
|
||||
public SmoothingModel readFrom(StreamInput in) throws IOException {
|
||||
return new Laplace(in.readDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(Laplace other) {
|
||||
return Objects.equals(alpha, other.alpha);
|
||||
protected boolean doEquals(SmoothingModel other) {
|
||||
Laplace otherModel = (Laplace) other;
|
||||
return Objects.equals(alpha, otherModel.alpha);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -434,7 +447,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
|
||||
@Override
|
||||
public Laplace fromXContent(QueryParseContext parseContext) throws IOException {
|
||||
public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException {
|
||||
XContentParser parser = parseContext.parser();
|
||||
XContentParser.Token token;
|
||||
String fieldName = null;
|
||||
|
@ -449,10 +462,16 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
return new Laplace(alpha);
|
||||
}
|
||||
|
||||
@Override
|
||||
public WordScorerFactory buildWordScorerFactory() {
|
||||
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator)
|
||||
-> new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static abstract class SmoothingModel<SM extends SmoothingModel<?>> implements NamedWriteable<SM>, ToXContent {
|
||||
public static abstract class SmoothingModel implements NamedWriteable<SmoothingModel>, ToXContent {
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
|
@ -471,16 +490,18 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
return false;
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
SM other = (SM) obj;
|
||||
SmoothingModel other = (SmoothingModel) obj;
|
||||
return doEquals(other);
|
||||
}
|
||||
|
||||
public abstract SM fromXContent(QueryParseContext parseContext) throws IOException;
|
||||
public abstract SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException;
|
||||
|
||||
public abstract WordScorerFactory buildWordScorerFactory();
|
||||
|
||||
/**
|
||||
* subtype specific implementation of "equals".
|
||||
*/
|
||||
protected abstract boolean doEquals(SM other);
|
||||
protected abstract boolean doEquals(SmoothingModel other);
|
||||
|
||||
protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
|
||||
}
|
||||
|
@ -493,7 +514,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
* Smoothing</a> for details.
|
||||
* </p>
|
||||
*/
|
||||
public static final class LinearInterpolation extends SmoothingModel<LinearInterpolation> {
|
||||
public static final class LinearInterpolation extends SmoothingModel {
|
||||
private static final String NAME = "linear";
|
||||
static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1);
|
||||
private final double trigramLambda;
|
||||
|
@ -563,10 +584,11 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(LinearInterpolation other) {
|
||||
return Objects.equals(trigramLambda, other.trigramLambda) &&
|
||||
Objects.equals(bigramLambda, other.bigramLambda) &&
|
||||
Objects.equals(unigramLambda, other.unigramLambda);
|
||||
protected boolean doEquals(SmoothingModel other) {
|
||||
final LinearInterpolation otherModel = (LinearInterpolation) other;
|
||||
return Objects.equals(trigramLambda, otherModel.trigramLambda) &&
|
||||
Objects.equals(bigramLambda, otherModel.bigramLambda) &&
|
||||
Objects.equals(unigramLambda, otherModel.unigramLambda);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -579,35 +601,45 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
|
|||
XContentParser parser = parseContext.parser();
|
||||
XContentParser.Token token;
|
||||
String fieldName = null;
|
||||
final double[] lambdas = new double[3];
|
||||
double trigramLambda = 0.0;
|
||||
double bigramLambda = 0.0;
|
||||
double unigramLambda = 0.0;
|
||||
ParseFieldMatcher matcher = parseContext.parseFieldMatcher();
|
||||
while ((token = parser.nextToken()) != Token.END_OBJECT) {
|
||||
if (token == XContentParser.Token.FIELD_NAME) {
|
||||
fieldName = parser.currentName();
|
||||
}
|
||||
if (token.isValue()) {
|
||||
} else if (token.isValue()) {
|
||||
if (matcher.match(fieldName, TRIGRAM_FIELD)) {
|
||||
lambdas[0] = parser.doubleValue();
|
||||
if (lambdas[0] < 0) {
|
||||
trigramLambda = parser.doubleValue();
|
||||
if (trigramLambda < 0) {
|
||||
throw new IllegalArgumentException("trigram_lambda must be positive");
|
||||
}
|
||||
} else if (matcher.match(fieldName, BIGRAM_FIELD)) {
|
||||
lambdas[1] = parser.doubleValue();
|
||||
if (lambdas[1] < 0) {
|
||||
bigramLambda = parser.doubleValue();
|
||||
if (bigramLambda < 0) {
|
||||
throw new IllegalArgumentException("bigram_lambda must be positive");
|
||||
}
|
||||
} else if (matcher.match(fieldName, UNIGRAM_FIELD)) {
|
||||
lambdas[2] = parser.doubleValue();
|
||||
if (lambdas[2] < 0) {
|
||||
unigramLambda = parser.doubleValue();
|
||||
if (unigramLambda < 0) {
|
||||
throw new IllegalArgumentException("unigram_lambda must be positive");
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]");
|
||||
}
|
||||
} else {
|
||||
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "] after [" + fieldName + "]");
|
||||
}
|
||||
}
|
||||
return new LinearInterpolation(lambdas[0], lambdas[1], lambdas[2]);
|
||||
return new LinearInterpolation(trigramLambda, bigramLambda, unigramLambda);
|
||||
}
|
||||
|
||||
@Override
|
||||
public WordScorerFactory buildWordScorerFactory() {
|
||||
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) ->
|
||||
new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda,
|
||||
unigramLambda);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,6 +42,10 @@ public class StupidBackoffScorer extends WordScorer {
|
|||
this.discount = discount;
|
||||
}
|
||||
|
||||
double discount() {
|
||||
return this.discount;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
|
||||
SuggestUtils.join(separator, spare, w_1.term, word.term);
|
||||
|
|
|
@ -20,11 +20,14 @@
|
|||
package org.elasticsearch.search.suggest.phrase;
|
||||
|
||||
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace;
|
||||
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
|
||||
|
||||
public class LaplaceModelTests extends SmoothingModelTest<Laplace> {
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class LaplaceModelTests extends SmoothingModelTestCase {
|
||||
|
||||
@Override
|
||||
protected Laplace createTestModel() {
|
||||
protected SmoothingModel createTestModel() {
|
||||
return new Laplace(randomDoubleBetween(0.0, 10.0, false));
|
||||
}
|
||||
|
||||
|
@ -32,7 +35,15 @@ public class LaplaceModelTests extends SmoothingModelTest<Laplace> {
|
|||
* mutate the given model so the returned smoothing model is different
|
||||
*/
|
||||
@Override
|
||||
protected Laplace createMutation(Laplace original) {
|
||||
protected Laplace createMutation(SmoothingModel input) {
|
||||
Laplace original = (Laplace) input;
|
||||
return new Laplace(original.getAlpha() + 0.1);
|
||||
}
|
||||
|
||||
@Override
|
||||
void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
|
||||
Laplace model = (Laplace) input;
|
||||
assertThat(wordScorer, instanceOf(LaplaceScorer.class));
|
||||
assertEquals(model.getAlpha(), ((LaplaceScorer) wordScorer).alpha(), Double.MIN_VALUE);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,15 +20,18 @@
|
|||
package org.elasticsearch.search.suggest.phrase;
|
||||
|
||||
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation;
|
||||
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
|
||||
|
||||
public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInterpolation> {
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class LinearInterpolationModelTests extends SmoothingModelTestCase {
|
||||
|
||||
@Override
|
||||
protected LinearInterpolation createTestModel() {
|
||||
protected SmoothingModel createTestModel() {
|
||||
double trigramLambda = randomDoubleBetween(0.0, 10.0, false);
|
||||
double bigramLambda = randomDoubleBetween(0.0, 10.0, false);
|
||||
double unigramLambda = randomDoubleBetween(0.0, 10.0, false);
|
||||
// normalize
|
||||
// normalize so parameters sum to 1
|
||||
double sum = trigramLambda + bigramLambda + unigramLambda;
|
||||
return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum);
|
||||
}
|
||||
|
@ -37,7 +40,8 @@ public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInte
|
|||
* mutate the given model so the returned smoothing model is different
|
||||
*/
|
||||
@Override
|
||||
protected LinearInterpolation createMutation(LinearInterpolation original) {
|
||||
protected LinearInterpolation createMutation(SmoothingModel input) {
|
||||
LinearInterpolation original = (LinearInterpolation) input;
|
||||
// swap two values permute original lambda values
|
||||
switch (randomIntBetween(0, 2)) {
|
||||
case 0:
|
||||
|
@ -52,4 +56,14 @@ public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInte
|
|||
return new LinearInterpolation(original.getUnigramLambda(), original.getBigramLambda(), original.getTrigramLambda());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
void assertWordScorer(WordScorer wordScorer, SmoothingModel in) {
|
||||
LinearInterpolation testModel = (LinearInterpolation) in;
|
||||
LinearInterpoatingScorer testScorer = (LinearInterpoatingScorer) wordScorer;
|
||||
assertThat(wordScorer, instanceOf(LinearInterpoatingScorer.class));
|
||||
assertEquals(testModel.getTrigramLambda(), (testScorer).trigramLambda(), 1e-15);
|
||||
assertEquals(testModel.getBigramLambda(), (testScorer).bigramLambda(), 1e-15);
|
||||
assertEquals(testModel.getUnigramLambda(), (testScorer).unigramLambda(), 1e-15);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,11 +19,23 @@
|
|||
|
||||
package org.elasticsearch.search.suggest.phrase;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
|
||||
import org.apache.lucene.analysis.miscellaneous.PerFieldAnalyzerWrapper;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.TextField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.store.RAMDirectory;
|
||||
import org.elasticsearch.common.ParseFieldMatcher;
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.lucene.BytesRefs;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
@ -43,14 +55,15 @@ import org.junit.BeforeClass;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends ESTestCase {
|
||||
public abstract class SmoothingModelTestCase extends ESTestCase {
|
||||
|
||||
private static NamedWriteableRegistry namedWriteableRegistry;
|
||||
private static IndicesQueriesRegistry indicesQueriesRegistry;
|
||||
|
||||
/**
|
||||
* setup for the whole base test class
|
||||
|
@ -63,33 +76,31 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
|
|||
namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE);
|
||||
namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE);
|
||||
}
|
||||
indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry);
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void afterClass() throws Exception {
|
||||
namedWriteableRegistry = null;
|
||||
indicesQueriesRegistry = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* create random model that is put under test
|
||||
*/
|
||||
protected abstract SM createTestModel();
|
||||
protected abstract SmoothingModel createTestModel();
|
||||
|
||||
/**
|
||||
* mutate the given model so the returned smoothing model is different
|
||||
*/
|
||||
protected abstract SM createMutation(SM original) throws IOException;
|
||||
protected abstract SmoothingModel createMutation(SmoothingModel original) throws IOException;
|
||||
|
||||
/**
|
||||
* Test that creates new smoothing model from a random test smoothing model and checks both for equality
|
||||
*/
|
||||
public void testFromXContent() throws IOException {
|
||||
QueryParseContext context = new QueryParseContext(indicesQueriesRegistry);
|
||||
QueryParseContext context = new QueryParseContext(new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptyMap()));
|
||||
context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY));
|
||||
|
||||
SM testModel = createTestModel();
|
||||
SmoothingModel testModel = createTestModel();
|
||||
XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
|
||||
if (randomBoolean()) {
|
||||
contentBuilder.prettyPrint();
|
||||
|
@ -99,21 +110,45 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
|
|||
contentBuilder.endObject();
|
||||
XContentParser parser = XContentHelper.createParser(contentBuilder.bytes());
|
||||
context.reset(parser);
|
||||
SmoothingModel<?> prototype = (SmoothingModel<?>) namedWriteableRegistry.getPrototype(SmoothingModel.class,
|
||||
parser.nextToken(); // go to start token, real parsing would do that in the outer element parser
|
||||
SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class,
|
||||
testModel.getWriteableName());
|
||||
SmoothingModel<?> parsedModel = prototype.fromXContent(context);
|
||||
SmoothingModel parsedModel = prototype.fromXContent(context);
|
||||
assertNotSame(testModel, parsedModel);
|
||||
assertEquals(testModel, parsedModel);
|
||||
assertEquals(testModel.hashCode(), parsedModel.hashCode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Test the WordScorer emitted by the smoothing model
|
||||
*/
|
||||
public void testBuildWordScorer() throws IOException {
|
||||
SmoothingModel testModel = createTestModel();
|
||||
|
||||
Map<String, Analyzer> mapping = new HashMap<>();
|
||||
mapping.put("field", new WhitespaceAnalyzer());
|
||||
PerFieldAnalyzerWrapper wrapper = new PerFieldAnalyzerWrapper(new WhitespaceAnalyzer(), mapping);
|
||||
IndexWriter writer = new IndexWriter(new RAMDirectory(), new IndexWriterConfig(wrapper));
|
||||
Document doc = new Document();
|
||||
doc.add(new Field("field", "someText", TextField.TYPE_NOT_STORED));
|
||||
writer.addDocument(doc);
|
||||
DirectoryReader ir = DirectoryReader.open(writer, false);
|
||||
|
||||
WordScorer wordScorer = testModel.buildWordScorerFactory().newScorer(ir, MultiFields.getTerms(ir , "field"), "field", 0.9d, BytesRefs.toBytesRef(" "));
|
||||
assertWordScorer(wordScorer, testModel);
|
||||
}
|
||||
|
||||
/**
|
||||
* implementation dependant assertions on the wordScorer produced by the smoothing model under test
|
||||
*/
|
||||
abstract void assertWordScorer(WordScorer wordScorer, SmoothingModel testModel);
|
||||
|
||||
/**
|
||||
* Test serialization and deserialization of the tested model.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testSerialization() throws IOException {
|
||||
SM testModel = createTestModel();
|
||||
SM deserializedModel = (SM) copyModel(testModel);
|
||||
SmoothingModel testModel = createTestModel();
|
||||
SmoothingModel deserializedModel = copyModel(testModel);
|
||||
assertEquals(testModel, deserializedModel);
|
||||
assertEquals(testModel.hashCode(), deserializedModel.hashCode());
|
||||
assertNotSame(testModel, deserializedModel);
|
||||
|
@ -124,7 +159,7 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
|
|||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testEqualsAndHashcode() throws IOException {
|
||||
SM firstModel = createTestModel();
|
||||
SmoothingModel firstModel = createTestModel();
|
||||
assertFalse("smoothing model is equal to null", firstModel.equals(null));
|
||||
assertFalse("smoothing model is equal to incompatible type", firstModel.equals(""));
|
||||
assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel));
|
||||
|
@ -132,13 +167,13 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
|
|||
equalTo(firstModel.hashCode()));
|
||||
assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel)));
|
||||
|
||||
SM secondModel = (SM) copyModel(firstModel);
|
||||
SmoothingModel secondModel = copyModel(firstModel);
|
||||
assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel));
|
||||
assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel));
|
||||
assertTrue("equals is not symmetric", secondModel.equals(firstModel));
|
||||
assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode()));
|
||||
|
||||
SM thirdModel = (SM) copyModel(secondModel);
|
||||
SmoothingModel thirdModel = copyModel(secondModel);
|
||||
assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel));
|
||||
assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel));
|
||||
assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode()));
|
||||
|
@ -148,11 +183,11 @@ public abstract class SmoothingModelTest<SM extends SmoothingModel<?>> extends E
|
|||
assertTrue("equals is not symmetric", thirdModel.equals(firstModel));
|
||||
}
|
||||
|
||||
static SmoothingModel<?> copyModel(SmoothingModel<?> original) throws IOException {
|
||||
static SmoothingModel copyModel(SmoothingModel original) throws IOException {
|
||||
try (BytesStreamOutput output = new BytesStreamOutput()) {
|
||||
original.writeTo(output);
|
||||
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) {
|
||||
SmoothingModel<?> prototype = (SmoothingModel<?>) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName());
|
||||
SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName());
|
||||
return prototype.readFrom(in);
|
||||
}
|
||||
}
|
|
@ -19,12 +19,15 @@
|
|||
|
||||
package org.elasticsearch.search.suggest.phrase;
|
||||
|
||||
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
|
||||
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff;
|
||||
|
||||
public class StupidBackoffModelTests extends SmoothingModelTest<StupidBackoff> {
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class StupidBackoffModelTests extends SmoothingModelTestCase {
|
||||
|
||||
@Override
|
||||
protected StupidBackoff createTestModel() {
|
||||
protected SmoothingModel createTestModel() {
|
||||
return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false));
|
||||
}
|
||||
|
||||
|
@ -32,7 +35,15 @@ public class StupidBackoffModelTests extends SmoothingModelTest<StupidBackoff> {
|
|||
* mutate the given model so the returned smoothing model is different
|
||||
*/
|
||||
@Override
|
||||
protected StupidBackoff createMutation(StupidBackoff original) {
|
||||
protected StupidBackoff createMutation(SmoothingModel input) {
|
||||
StupidBackoff original = (StupidBackoff) input;
|
||||
return new StupidBackoff(original.getDiscount() + 0.1);
|
||||
}
|
||||
|
||||
@Override
|
||||
void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
|
||||
assertThat(wordScorer, instanceOf(StupidBackoffScorer.class));
|
||||
StupidBackoff testModel = (StupidBackoff) input;
|
||||
assertEquals(testModel.getDiscount(), ((StupidBackoffScorer) wordScorer).discount(), Double.MIN_VALUE);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue