diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java
index 04d98c3827d..678f3082bac 100644
--- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java
+++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java
@@ -27,14 +27,14 @@ import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidat
import java.io.IOException;
//TODO public for tests
public final class LaplaceScorer extends WordScorer {
-
+
public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() {
@Override
public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5);
}
};
-
+
private double alpha;
public LaplaceScorer(IndexReader reader, Terms terms, String field,
@@ -42,7 +42,11 @@ public final class LaplaceScorer extends WordScorer {
super(reader, terms, field, realWordLikelyhood, separator);
this.alpha = alpha;
}
-
+
+    /**
+     * Returns the {@code alpha} value this scorer holds, as assigned in the constructor.
+     * Package-private: read by the smoothing model tests to check the scorer configuration.
+     */
+    double alpha() {
+        return alpha;
+    }
+
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);
diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java
index d2b1ba48b13..368d461fc53 100644
--- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java
+++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java
@@ -41,7 +41,19 @@ public final class LinearInterpoatingScorer extends WordScorer {
this.bigramLambda = bigramLambda / sum;
this.trigramLambda = trigramLambda / sum;
}
-
+
+    /**
+     * Returns the trigram weight held by this scorer. The constructor stores the supplied
+     * value divided by its {@code sum}, so this is the normalized weight rather than the raw
+     * constructor argument.
+     */
+    double trigramLambda() {
+        return trigramLambda;
+    }
+
+    /**
+     * Returns the bigram weight held by this scorer. The constructor stores the supplied
+     * value divided by its {@code sum}, so this is the normalized weight rather than the raw
+     * constructor argument.
+     */
+    double bigramLambda() {
+        return bigramLambda;
+    }
+
+    /**
+     * Returns the unigram weight held by this scorer.
+     * NOTE(review): presumably normalized in the constructor like the bigram and trigram
+     * weights; that assignment is outside this hunk, so confirm.
+     */
+    double unigramLambda() {
+        return unigramLambda;
+    }
+
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);
diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java
index 97ca09d25a1..0e1fec6c7b2 100644
--- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java
+++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java
@@ -18,8 +18,12 @@
*/
package org.elasticsearch.search.suggest.phrase;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcher;
+import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -30,6 +34,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.script.Template;
import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder;
+import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory;
import java.io.IOException;
import java.util.ArrayList;
@@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder> generators = new HashMap<>();
private Integer gramSize;
- private SmoothingModel> model;
+ private SmoothingModel model;
private Boolean forceUnigrams;
private Integer tokenLimit;
private String preTag;
@@ -159,7 +164,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder model) {
+ public PhraseSuggestionBuilder smoothingModel(SmoothingModel model) {
this.model = model;
return this;
}
@@ -292,7 +297,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details.
*
*/
- public static final class StupidBackoff extends SmoothingModel {
+ public static final class StupidBackoff extends SmoothingModel {
/**
* Default discount parameter for {@link StupidBackoff} smoothing
*/
@@ -341,8 +346,9 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount);
+ }
}
/**
@@ -377,7 +389,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details.
*
*/
- public static final class Laplace extends SmoothingModel {
+ public static final class Laplace extends SmoothingModel {
private double alpha = DEFAULT_LAPLACE_ALPHA;
private static final String NAME = "laplace";
private static final ParseField ALPHA_FIELD = new ParseField("alpha");
@@ -419,13 +431,14 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha);
+ }
}
- public static abstract class SmoothingModel> implements NamedWriteable, ToXContent {
+ public static abstract class SmoothingModel implements NamedWriteable, ToXContent {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
@@ -471,16 +490,18 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details.
*
*/
- public static final class LinearInterpolation extends SmoothingModel {
+ public static final class LinearInterpolation extends SmoothingModel {
private static final String NAME = "linear";
static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1);
private final double trigramLambda;
@@ -563,10 +584,11 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder
+ new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda,
+ unigramLambda);
}
}
diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java
index fcf6064d228..5bd3d942b1a 100644
--- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java
+++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java
@@ -42,6 +42,10 @@ public class StupidBackoffScorer extends WordScorer {
this.discount = discount;
}
+    /**
+     * Returns the {@code discount} value this scorer holds, as assigned in the constructor.
+     */
+    double discount() {
+        return discount;
+    }
+
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);
diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java
index e2256e98f6a..87ad654e0cd 100644
--- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java
+++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java
@@ -20,11 +20,14 @@
package org.elasticsearch.search.suggest.phrase;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace;
+import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
-public class LaplaceModelTests extends SmoothingModelTest {
+import static org.hamcrest.Matchers.instanceOf;
+
+public class LaplaceModelTests extends SmoothingModelTestCase {
@Override
- protected Laplace createTestModel() {
+ protected SmoothingModel createTestModel() {
return new Laplace(randomDoubleBetween(0.0, 10.0, false));
}
@@ -32,7 +35,15 @@ public class LaplaceModelTests extends SmoothingModelTest {
* mutate the given model so the returned smoothing model is different
*/
@Override
- protected Laplace createMutation(Laplace original) {
+ protected Laplace createMutation(SmoothingModel input) {
+ Laplace original = (Laplace) input;
return new Laplace(original.getAlpha() + 0.1);
}
+
+    /**
+     * Checks that the scorer built from a {@link Laplace} model is a {@link LaplaceScorer}
+     * carrying the same alpha as the model.
+     */
+    @Override
+    void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
+        Laplace model = (Laplace) input;
+        assertThat(wordScorer, instanceOf(LaplaceScorer.class));
+        LaplaceScorer laplaceScorer = (LaplaceScorer) wordScorer;
+        assertEquals(model.getAlpha(), laplaceScorer.alpha(), Double.MIN_VALUE);
+    }
}
diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java
index 467bca7f0ab..1112b7a5ed7 100644
--- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java
+++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java
@@ -20,15 +20,18 @@
package org.elasticsearch.search.suggest.phrase;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation;
+import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
-public class LinearInterpolationModelTests extends SmoothingModelTest {
+import static org.hamcrest.Matchers.instanceOf;
+
+public class LinearInterpolationModelTests extends SmoothingModelTestCase {
@Override
- protected LinearInterpolation createTestModel() {
+ protected SmoothingModel createTestModel() {
double trigramLambda = randomDoubleBetween(0.0, 10.0, false);
double bigramLambda = randomDoubleBetween(0.0, 10.0, false);
double unigramLambda = randomDoubleBetween(0.0, 10.0, false);
- // normalize
+ // normalize so parameters sum to 1
double sum = trigramLambda + bigramLambda + unigramLambda;
return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum);
}
@@ -37,7 +40,8 @@ public class LinearInterpolationModelTests extends SmoothingModelTest> extends ESTestCase {
+public abstract class SmoothingModelTestCase extends ESTestCase {
private static NamedWriteableRegistry namedWriteableRegistry;
- private static IndicesQueriesRegistry indicesQueriesRegistry;
/**
* setup for the whole base test class
@@ -63,33 +76,31 @@ public abstract class SmoothingModelTest> extends E
namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE);
namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE);
}
- indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry);
}
@AfterClass
public static void afterClass() throws Exception {
namedWriteableRegistry = null;
- indicesQueriesRegistry = null;
}
/**
* create random model that is put under test
*/
- protected abstract SM createTestModel();
+ protected abstract SmoothingModel createTestModel();
/**
* mutate the given model so the returned smoothing model is different
*/
- protected abstract SM createMutation(SM original) throws IOException;
+ protected abstract SmoothingModel createMutation(SmoothingModel original) throws IOException;
/**
* Test that creates new smoothing model from a random test smoothing model and checks both for equality
*/
public void testFromXContent() throws IOException {
- QueryParseContext context = new QueryParseContext(indicesQueriesRegistry);
+ QueryParseContext context = new QueryParseContext(new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptyMap()));
context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY));
- SM testModel = createTestModel();
+ SmoothingModel testModel = createTestModel();
XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
if (randomBoolean()) {
contentBuilder.prettyPrint();
@@ -99,21 +110,45 @@ public abstract class SmoothingModelTest> extends E
contentBuilder.endObject();
XContentParser parser = XContentHelper.createParser(contentBuilder.bytes());
context.reset(parser);
- SmoothingModel> prototype = (SmoothingModel>) namedWriteableRegistry.getPrototype(SmoothingModel.class,
+ parser.nextToken(); // go to start token, real parsing would do that in the outer element parser
+ SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class,
testModel.getWriteableName());
- SmoothingModel> parsedModel = prototype.fromXContent(context);
+ SmoothingModel parsedModel = prototype.fromXContent(context);
assertNotSame(testModel, parsedModel);
assertEquals(testModel, parsedModel);
assertEquals(testModel.hashCode(), parsedModel.hashCode());
}
+    /**
+     * Test the WordScorer emitted by the smoothing model.
+     * <p>
+     * Indexes a single document into an in-memory directory, builds a scorer from the model's
+     * factory against a reader on that index and hands it to
+     * {@code assertWordScorer(WordScorer, SmoothingModel)} for the implementation specific checks.
+     *
+     * @throws IOException if indexing the document, opening the reader or building the scorer fails
+     */
+    public void testBuildWordScorer() throws IOException {
+        SmoothingModel testModel = createTestModel();
+
+        Map mapping = new HashMap<>();
+        mapping.put("field", new WhitespaceAnalyzer());
+        PerFieldAnalyzerWrapper wrapper = new PerFieldAnalyzerWrapper(new WhitespaceAnalyzer(), mapping);
+        // Directory, writer and reader all hold index resources; close them even when an assertion fails.
+        try (RAMDirectory directory = new RAMDirectory();
+             IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(wrapper))) {
+            Document doc = new Document();
+            doc.add(new Field("field", "someText", TextField.TYPE_NOT_STORED));
+            writer.addDocument(doc);
+            // The reader is opened from the writer, so it is closed before the writer by the nested try.
+            try (DirectoryReader ir = DirectoryReader.open(writer, false)) {
+                WordScorer wordScorer = testModel.buildWordScorerFactory().newScorer(ir, MultiFields.getTerms(ir, "field"), "field",
+                        0.9d, BytesRefs.toBytesRef(" "));
+                assertWordScorer(wordScorer, testModel);
+            }
+        }
+    }
+
+ /**
+ * implementation dependent assertions on the wordScorer produced by the smoothing model under test
+ */
+ abstract void assertWordScorer(WordScorer wordScorer, SmoothingModel testModel);
+
/**
* Test serialization and deserialization of the tested model.
*/
- @SuppressWarnings("unchecked")
public void testSerialization() throws IOException {
- SM testModel = createTestModel();
- SM deserializedModel = (SM) copyModel(testModel);
+ SmoothingModel testModel = createTestModel();
+ SmoothingModel deserializedModel = copyModel(testModel);
assertEquals(testModel, deserializedModel);
assertEquals(testModel.hashCode(), deserializedModel.hashCode());
assertNotSame(testModel, deserializedModel);
@@ -124,7 +159,7 @@ public abstract class SmoothingModelTest> extends E
*/
@SuppressWarnings("unchecked")
public void testEqualsAndHashcode() throws IOException {
- SM firstModel = createTestModel();
+ SmoothingModel firstModel = createTestModel();
assertFalse("smoothing model is equal to null", firstModel.equals(null));
assertFalse("smoothing model is equal to incompatible type", firstModel.equals(""));
assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel));
@@ -132,13 +167,13 @@ public abstract class SmoothingModelTest> extends E
equalTo(firstModel.hashCode()));
assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel)));
- SM secondModel = (SM) copyModel(firstModel);
+ SmoothingModel secondModel = copyModel(firstModel);
assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel));
assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel));
assertTrue("equals is not symmetric", secondModel.equals(firstModel));
assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode()));
- SM thirdModel = (SM) copyModel(secondModel);
+ SmoothingModel thirdModel = copyModel(secondModel);
assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel));
assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel));
assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode()));
@@ -148,11 +183,11 @@ public abstract class SmoothingModelTest> extends E
assertTrue("equals is not symmetric", thirdModel.equals(firstModel));
}
- static SmoothingModel> copyModel(SmoothingModel> original) throws IOException {
+ static SmoothingModel copyModel(SmoothingModel original) throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
original.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) {
- SmoothingModel> prototype = (SmoothingModel>) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName());
+ SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName());
return prototype.readFrom(in);
}
}
diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java
index 5d774066e07..c3bd66d2a81 100644
--- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java
+++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java
@@ -19,12 +19,15 @@
package org.elasticsearch.search.suggest.phrase;
+import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff;
-public class StupidBackoffModelTests extends SmoothingModelTest {
+import static org.hamcrest.Matchers.instanceOf;
+
+public class StupidBackoffModelTests extends SmoothingModelTestCase {
@Override
- protected StupidBackoff createTestModel() {
+ protected SmoothingModel createTestModel() {
return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false));
}
@@ -32,7 +35,15 @@ public class StupidBackoffModelTests extends SmoothingModelTest {
* mutate the given model so the returned smoothing model is different
*/
@Override
- protected StupidBackoff createMutation(StupidBackoff original) {
+ protected StupidBackoff createMutation(SmoothingModel input) {
+ StupidBackoff original = (StupidBackoff) input;
return new StupidBackoff(original.getDiscount() + 0.1);
}
+
+    /**
+     * Checks that the scorer built from a {@link StupidBackoff} model is a
+     * {@link StupidBackoffScorer} carrying the same discount as the model.
+     */
+    @Override
+    void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
+        assertThat(wordScorer, instanceOf(StupidBackoffScorer.class));
+        StupidBackoff testModel = (StupidBackoff) input;
+        StupidBackoffScorer backoffScorer = (StupidBackoffScorer) wordScorer;
+        assertEquals(testModel.getDiscount(), backoffScorer.discount(), Double.MIN_VALUE);
+    }
}