diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java index 04d98c3827d..678f3082bac 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java @@ -27,14 +27,14 @@ import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidat import java.io.IOException; //TODO public for tests public final class LaplaceScorer extends WordScorer { - + public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() { @Override public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException { return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5); } }; - + private double alpha; public LaplaceScorer(IndexReader reader, Terms terms, String field, @@ -42,7 +42,11 @@ public final class LaplaceScorer extends WordScorer { super(reader, terms, field, realWordLikelyhood, separator); this.alpha = alpha; } - + + double alpha() { + return this.alpha; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java index d2b1ba48b13..368d461fc53 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java @@ -41,7 +41,19 @@ public final class LinearInterpoatingScorer extends WordScorer { this.bigramLambda = bigramLambda / sum; this.trigramLambda = trigramLambda / sum; } - + + double trigramLambda() { + return this.trigramLambda; + } + + double bigramLambda() { + return this.bigramLambda; + } + + double unigramLambda() { + return this.unigramLambda; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java index 97ca09d25a1..0e1fec6c7b2 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java @@ -18,8 +18,12 @@ */ package org.elasticsearch.search.suggest.phrase; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Terms; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -30,6 +34,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token; import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder; +import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory; import java.io.IOException; import java.util.ArrayList; @@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder> generators = new HashMap<>(); private Integer gramSize; - private SmoothingModel model; + private SmoothingModel model; private Boolean forceUnigrams; private Integer tokenLimit; private String preTag; @@ -159,7 +164,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder model) { + public PhraseSuggestionBuilder smoothingModel(SmoothingModel model) { this.model = model; return this; } @@ -292,7 +297,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class StupidBackoff extends SmoothingModel { + public static final class StupidBackoff extends SmoothingModel { /** * Default discount parameter for {@link StupidBackoff} smoothing */ @@ -341,8 +346,9 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount); + } } /** @@ -377,7 +389,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class Laplace extends SmoothingModel { + public static final class Laplace extends SmoothingModel { private double alpha = DEFAULT_LAPLACE_ALPHA; private static final String NAME = "laplace"; private static final ParseField ALPHA_FIELD = new ParseField("alpha"); @@ -419,13 +431,14 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha); + } } - public static abstract class SmoothingModel> implements NamedWriteable, ToXContent { + public static abstract class SmoothingModel implements NamedWriteable, ToXContent { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { @@ -471,16 +490,18 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class LinearInterpolation extends SmoothingModel { + public static final class LinearInterpolation extends SmoothingModel { private static final String NAME = "linear"; static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1); private final double trigramLambda; @@ -563,10 +584,11 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder + new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda, + unigramLambda); } } diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java index fcf6064d228..5bd3d942b1a 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java @@ -42,6 +42,10 @@ public class StupidBackoffScorer extends WordScorer { this.discount = discount; } + double discount() { + return this.discount; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java index e2256e98f6a..87ad654e0cd 100644 --- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java @@ -20,11 +20,14 @@ package org.elasticsearch.search.suggest.phrase; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; -public class LaplaceModelTests extends SmoothingModelTest { +import static org.hamcrest.Matchers.instanceOf; + +public class LaplaceModelTests extends SmoothingModelTestCase { @Override - protected Laplace createTestModel() { + protected SmoothingModel createTestModel() { return new Laplace(randomDoubleBetween(0.0, 10.0, false)); } @@ -32,7 +35,15 @@ public class LaplaceModelTests extends SmoothingModelTest { * mutate the given model so the returned smoothing model is different */ @Override - protected Laplace createMutation(Laplace original) { + protected Laplace createMutation(SmoothingModel input) { + Laplace original = (Laplace) input; return new Laplace(original.getAlpha() + 0.1); } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel input) { + Laplace model = (Laplace) input; + assertThat(wordScorer, instanceOf(LaplaceScorer.class)); + assertEquals(model.getAlpha(), ((LaplaceScorer) wordScorer).alpha(), Double.MIN_VALUE); + } } diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java index 467bca7f0ab..1112b7a5ed7 100644 --- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java @@ -20,15 +20,18 @@ package org.elasticsearch.search.suggest.phrase; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; -public class LinearInterpolationModelTests extends SmoothingModelTest { +import static org.hamcrest.Matchers.instanceOf; + +public class LinearInterpolationModelTests extends SmoothingModelTestCase { @Override - protected LinearInterpolation createTestModel() { + protected SmoothingModel createTestModel() { double trigramLambda = randomDoubleBetween(0.0, 10.0, false); double bigramLambda = randomDoubleBetween(0.0, 10.0, false); double unigramLambda = randomDoubleBetween(0.0, 10.0, false); - // normalize + // normalize so parameters sum to 1 double sum = trigramLambda + bigramLambda + unigramLambda; return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum); } @@ -37,7 +40,8 @@ public class LinearInterpolationModelTests extends SmoothingModelTest> extends ESTestCase { +public abstract class SmoothingModelTestCase extends ESTestCase { private static NamedWriteableRegistry namedWriteableRegistry; - private static IndicesQueriesRegistry indicesQueriesRegistry; /** * setup for the whole base test class @@ -63,33 +76,31 @@ public abstract class SmoothingModelTest> extends E namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE); namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE); } - indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry); } @AfterClass public static void afterClass() throws Exception { namedWriteableRegistry = null; - indicesQueriesRegistry = null; } /** * create random model that is put under test */ - protected abstract SM createTestModel(); + protected abstract SmoothingModel createTestModel(); /** * mutate the given model so the returned smoothing model is different */ - protected abstract SM createMutation(SM original) throws IOException; + protected abstract SmoothingModel createMutation(SmoothingModel original) throws IOException; /** * Test that creates new smoothing model from a random test smoothing model and checks both for equality */ public void testFromXContent() throws IOException { - QueryParseContext context = new QueryParseContext(indicesQueriesRegistry); + QueryParseContext context = new QueryParseContext(new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptyMap())); context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY)); - SM testModel = createTestModel(); + SmoothingModel testModel = createTestModel(); XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); if (randomBoolean()) { contentBuilder.prettyPrint(); @@ -99,21 +110,45 @@ public abstract class SmoothingModelTest> extends E contentBuilder.endObject(); XContentParser parser = XContentHelper.createParser(contentBuilder.bytes()); context.reset(parser); - SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, + parser.nextToken(); // go to start token, real parsing would do that in the outer element parser + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, testModel.getWriteableName()); - SmoothingModel parsedModel = prototype.fromXContent(context); + SmoothingModel parsedModel = prototype.fromXContent(context); assertNotSame(testModel, parsedModel); assertEquals(testModel, parsedModel); assertEquals(testModel.hashCode(), parsedModel.hashCode()); } + /** + * Test the WordScorer emitted by the smoothing model + */ + public void testBuildWordScorer() throws IOException { + SmoothingModel testModel = createTestModel(); + + Map mapping = new HashMap<>(); + mapping.put("field", new WhitespaceAnalyzer()); + PerFieldAnalyzerWrapper wrapper = new PerFieldAnalyzerWrapper(new WhitespaceAnalyzer(), mapping); + IndexWriter writer = new IndexWriter(new RAMDirectory(), new IndexWriterConfig(wrapper)); + Document doc = new Document(); + doc.add(new Field("field", "someText", TextField.TYPE_NOT_STORED)); + writer.addDocument(doc); + DirectoryReader ir = DirectoryReader.open(writer, false); + + WordScorer wordScorer = testModel.buildWordScorerFactory().newScorer(ir, MultiFields.getTerms(ir , "field"), "field", 0.9d, BytesRefs.toBytesRef(" ")); + assertWordScorer(wordScorer, testModel); + } + + /** + * implementation dependant assertions on the wordScorer produced by the smoothing model under test + */ + abstract void assertWordScorer(WordScorer wordScorer, SmoothingModel testModel); + /** * Test serialization and deserialization of the tested model. */ - @SuppressWarnings("unchecked") public void testSerialization() throws IOException { - SM testModel = createTestModel(); - SM deserializedModel = (SM) copyModel(testModel); + SmoothingModel testModel = createTestModel(); + SmoothingModel deserializedModel = copyModel(testModel); assertEquals(testModel, deserializedModel); assertEquals(testModel.hashCode(), deserializedModel.hashCode()); assertNotSame(testModel, deserializedModel); @@ -124,7 +159,7 @@ public abstract class SmoothingModelTest> extends E */ @SuppressWarnings("unchecked") public void testEqualsAndHashcode() throws IOException { - SM firstModel = createTestModel(); + SmoothingModel firstModel = createTestModel(); assertFalse("smoothing model is equal to null", firstModel.equals(null)); assertFalse("smoothing model is equal to incompatible type", firstModel.equals("")); assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel)); @@ -132,13 +167,13 @@ public abstract class SmoothingModelTest> extends E equalTo(firstModel.hashCode())); assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel))); - SM secondModel = (SM) copyModel(firstModel); + SmoothingModel secondModel = copyModel(firstModel); assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel)); assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel)); assertTrue("equals is not symmetric", secondModel.equals(firstModel)); assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode())); - SM thirdModel = (SM) copyModel(secondModel); + SmoothingModel thirdModel = copyModel(secondModel); assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel)); assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel)); assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode())); @@ -148,11 +183,11 @@ public abstract class SmoothingModelTest> extends E assertTrue("equals is not symmetric", thirdModel.equals(firstModel)); } - static SmoothingModel copyModel(SmoothingModel original) throws IOException { + static SmoothingModel copyModel(SmoothingModel original) throws IOException { try (BytesStreamOutput output = new BytesStreamOutput()) { original.writeTo(output); try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { - SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); return prototype.readFrom(in); } } diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java index 5d774066e07..c3bd66d2a81 100644 --- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java @@ -19,12 +19,15 @@ package org.elasticsearch.search.suggest.phrase; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; -public class StupidBackoffModelTests extends SmoothingModelTest { +import static org.hamcrest.Matchers.instanceOf; + +public class StupidBackoffModelTests extends SmoothingModelTestCase { @Override - protected StupidBackoff createTestModel() { + protected SmoothingModel createTestModel() { return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false)); } @@ -32,7 +35,15 @@ public class StupidBackoffModelTests extends SmoothingModelTest { * mutate the given model so the returned smoothing model is different */ @Override - protected StupidBackoff createMutation(StupidBackoff original) { + protected StupidBackoff createMutation(SmoothingModel input) { + StupidBackoff original = (StupidBackoff) input; return new StupidBackoff(original.getDiscount() + 0.1); } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel input) { + assertThat(wordScorer, instanceOf(StupidBackoffScorer.class)); + StupidBackoff testModel = (StupidBackoff) input; + assertEquals(testModel.getDiscount(), ((StupidBackoffScorer) wordScorer).discount(), Double.MIN_VALUE); + } }