Move smoothing model into its own sub-object in the PhraseSuggest request

Closes #2735
This commit is contained in:
Simon Willnauer 2013-03-06 14:31:21 +01:00
parent e1409a9f0e
commit 1f217f6a7b
2 changed files with 109 additions and 100 deletions

View File

@ -89,8 +89,9 @@ public final class PhraseSuggestParser implements SuggestContextParser {
gramSizeSet = true;
} else if ("force_unigrams".equals(fieldName)) {
suggestion.setRequireUnigram(parser.booleanValue());
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support field [" + fieldName + "]");
}
}
} else if (token == Token.START_ARRAY) {
if ("direct_generator".equals(fieldName)) {
@ -111,97 +112,8 @@ public final class PhraseSuggestParser implements SuggestContextParser {
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support array field [" + fieldName + "]");
}
} else if (token == Token.START_OBJECT) {
if ("linear".equals(fieldName)) {
ensureNoSmoothing(suggestion);
final double[] lambdas = new double[3];
while ((token = parser.nextToken()) != Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
}
if (token.isValue()) {
if ("trigram_lambda".equals(fieldName)) {
lambdas[0] = parser.doubleValue();
if (lambdas[0] < 0) {
throw new ElasticSearchIllegalArgumentException("trigram_lambda must be positive");
}
}
if ("bigram_lambda".equals(fieldName)) {
lambdas[1] = parser.doubleValue();
if (lambdas[1] < 0) {
throw new ElasticSearchIllegalArgumentException("bigram_lambda must be positive");
}
}
if ("unigram_lambda".equals(fieldName)) {
lambdas[2] = parser.doubleValue();
if (lambdas[2] < 0) {
throw new ElasticSearchIllegalArgumentException("unigram_lambda must be positive");
}
}
}
}
double sum = 0.0d;
for (int i = 0; i < lambdas.length; i++) {
sum += lambdas[i];
}
if (Math.abs(sum - 1.0) > 0.001) {
throw new ElasticSearchIllegalArgumentException("linear smoothing lambdas must sum to 1");
}
suggestion.setModel(new WordScorer.WordScorerFactory() {
@Override
public WordScorer newScorer(IndexReader reader, String field, double realWordLikelyhood, BytesRef separator)
throws IOException {
return new LinearInterpoatingScorer(reader, field, realWordLikelyhood, separator, lambdas[0], lambdas[1],
lambdas[2]);
}
});
} else if ("laplace".equals(fieldName)) {
ensureNoSmoothing(suggestion);
double theAlpha = 0.5;
while ((token = parser.nextToken()) != Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
}
if (token.isValue()) {
if ("alpha".equals(fieldName)) {
theAlpha = parser.doubleValue();
}
}
}
final double alpha = theAlpha;
suggestion.setModel( new WordScorer.WordScorerFactory() {
@Override
public WordScorer newScorer(IndexReader reader, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
return new LaplaceScorer(reader, field, realWordLikelyhood, separator, alpha);
}
});
} else if ("stupid_backoff".equals(fieldName)) {
ensureNoSmoothing(suggestion);
double theDiscount = 0.4;
while ((token = parser.nextToken()) != Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
}
if (token.isValue()) {
if ("discount".equals(fieldName)) {
theDiscount = parser.doubleValue();
}
}
}
final double discount = theDiscount;
suggestion.setModel( new WordScorer.WordScorerFactory() {
@Override
public WordScorer newScorer(IndexReader reader, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
return new StupidBackoffScorer(reader, field, realWordLikelyhood, separator, discount);
}
});
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support object field [" + fieldName + "]");
}
} else if (token == Token.START_OBJECT && "smoothing".equals(fieldName)) {
parseSmoothingModel(parser, suggestion, fieldName);
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support field [" + fieldName + "]");
}
@ -242,6 +154,101 @@ public final class PhraseSuggestParser implements SuggestContextParser {
return suggestion;
}
/**
 * Parses the content of a {@code smoothing} object and installs the selected model on the suggestion.
 * <p>
 * The object must hold exactly one of {@code linear}, {@code laplace} or {@code stupid_backoff}, each of
 * which is itself an object carrying that model's parameters. The caller invokes this method while the
 * parser is positioned on the {@code START_OBJECT} of the smoothing object; on return the parser is
 * positioned on the matching {@code END_OBJECT}, so the caller's token loop resumes with whatever field
 * follows {@code smoothing} instead of mistaking the smoothing object's end for its own.
 *
 * @param parser     the parser to pull tokens from
 * @param suggestion the context that receives the parsed model via {@code setModel}
 * @param fieldName  name of the field being parsed; only used in the error raised for an empty object
 * @throws IOException if the parser fails while reading tokens
 * @throws ElasticSearchIllegalArgumentException if the object is empty, names an unknown model, names more
 *         than one model, gives a model something other than an object, or carries invalid parameters
 */
public void parseSmoothingModel(XContentParser parser, PhraseSuggestionContext suggestion, String fieldName) throws IOException {
    boolean modelParsed = false;
    XContentParser.Token token;
    // Run until the END_OBJECT of the smoothing object itself; each model helper below stops on the
    // END_OBJECT of its own (nested) object, so the two never get confused.
    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token != XContentParser.Token.FIELD_NAME) {
            continue;
        }
        final String modelName = parser.currentName();
        final boolean linear = "linear".equals(modelName);
        final boolean laplace = "laplace".equals(modelName);
        final boolean stupidBackoff = "stupid_backoff".equals(modelName);
        if (!linear && !laplace && !stupidBackoff) {
            throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support object field [" + modelName + "]");
        }
        // A second model inside the same smoothing object is rejected here as well.
        ensureNoSmoothing(suggestion);
        // Without this check a scalar such as "laplace": 0.5 would make the model loop swallow the
        // END_OBJECT of the smoothing object and leave the parser misaligned for the caller.
        if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
            throw new ElasticSearchIllegalArgumentException("suggester[phrase][smoothing] expects an object for [" + modelName + "]");
        }
        final WordScorer.WordScorerFactory model;
        if (linear) {
            model = parseLinearModel(parser);
        } else if (laplace) {
            model = parseLaplaceModel(parser);
        } else {
            model = parseStupidBackoffModel(parser);
        }
        suggestion.setModel(model);
        modelParsed = true;
    }
    if (!modelParsed) {
        // Empty smoothing object: keep reporting the enclosing field name, as before.
        throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support object field [" + fieldName + "]");
    }
}

/** Reads the {@code linear} parameters up to the model's END_OBJECT and builds the matching scorer factory. */
private static WordScorer.WordScorerFactory parseLinearModel(XContentParser parser) throws IOException {
    // Index 0 = trigram, 1 = bigram, 2 = unigram; entries that are not supplied stay 0.
    final double[] lambdas = new double[3];
    String name = null;
    XContentParser.Token token;
    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token == XContentParser.Token.FIELD_NAME) {
            name = parser.currentName();
        }
        if (token.isValue()) {
            if ("trigram_lambda".equals(name)) {
                lambdas[0] = parser.doubleValue();
                if (lambdas[0] < 0) {
                    throw new ElasticSearchIllegalArgumentException("trigram_lambda must be positive");
                }
            } else if ("bigram_lambda".equals(name)) {
                lambdas[1] = parser.doubleValue();
                if (lambdas[1] < 0) {
                    throw new ElasticSearchIllegalArgumentException("bigram_lambda must be positive");
                }
            } else if ("unigram_lambda".equals(name)) {
                lambdas[2] = parser.doubleValue();
                if (lambdas[2] < 0) {
                    throw new ElasticSearchIllegalArgumentException("unigram_lambda must be positive");
                }
            } else {
                throw new ElasticSearchIllegalArgumentException("suggester[phrase][smoothing][linear] doesn't support field [" + name + "]");
            }
        }
    }
    double sum = 0.0d;
    for (int i = 0; i < lambdas.length; i++) {
        sum += lambdas[i];
    }
    // The three weights must add up to 1, within a small tolerance for floating point input.
    if (Math.abs(sum - 1.0) > 0.001) {
        throw new ElasticSearchIllegalArgumentException("linear smoothing lambdas must sum to 1");
    }
    return new WordScorer.WordScorerFactory() {
        @Override
        public WordScorer newScorer(IndexReader reader, String field, double realWordLikelyhood, BytesRef separator)
                throws IOException {
            return new LinearInterpoatingScorer(reader, field, realWordLikelyhood, separator, lambdas[0], lambdas[1],
                    lambdas[2]);
        }
    };
}

/** Reads the {@code laplace} parameters (only {@code alpha}, default 0.5) up to the model's END_OBJECT. */
private static WordScorer.WordScorerFactory parseLaplaceModel(XContentParser parser) throws IOException {
    double theAlpha = 0.5;
    String name = null;
    XContentParser.Token token;
    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token == XContentParser.Token.FIELD_NAME) {
            name = parser.currentName();
        }
        // Values of any other field are skipped, matching the previous lenient behaviour.
        if (token.isValue() && "alpha".equals(name)) {
            theAlpha = parser.doubleValue();
        }
    }
    final double alpha = theAlpha;
    return new WordScorer.WordScorerFactory() {
        @Override
        public WordScorer newScorer(IndexReader reader, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
            return new LaplaceScorer(reader, field, realWordLikelyhood, separator, alpha);
        }
    };
}

/** Reads the {@code stupid_backoff} parameters (only {@code discount}, default 0.4) up to the model's END_OBJECT. */
private static WordScorer.WordScorerFactory parseStupidBackoffModel(XContentParser parser) throws IOException {
    double theDiscount = 0.4;
    String name = null;
    XContentParser.Token token;
    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token == XContentParser.Token.FIELD_NAME) {
            name = parser.currentName();
        }
        // Values of any other field are skipped, matching the previous lenient behaviour.
        if (token.isValue() && "discount".equals(name)) {
            theDiscount = parser.doubleValue();
        }
    }
    final double discount = theDiscount;
    return new WordScorer.WordScorerFactory() {
        @Override
        public WordScorer newScorer(IndexReader reader, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
            return new StupidBackoffScorer(reader, field, realWordLikelyhood, separator, discount);
        }
    };
}
private void ensureNoSmoothing(PhraseSuggestionContext suggestion) {
if (suggestion.model() != null) {
throw new ElasticSearchIllegalArgumentException("only one smoothing model supported");

View File

@ -172,7 +172,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
}
if (model != null) {
builder.startObject(model.type);
builder.startObject("smoothing");
model.toXContent(builder, params);
builder.endObject();
}
@ -214,8 +214,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder = super.toXContent(builder, params);
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("discount", discount);
return builder;
}
@ -245,15 +244,14 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder = super.toXContent(builder, params);
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("alpha", alpha);
return builder;
}
}
public static class SmoothingModel implements ToXContent {
public static abstract class SmoothingModel implements ToXContent {
private final String type;
protected SmoothingModel(String type) {
@ -262,8 +260,13 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
/**
 * Writes this smoothing model as an object whose field name is the model's {@code type}. The fields
 * inside that object are produced by {@code innerToXContent}, whose return value is not used here.
 *
 * @param builder the builder to write to
 * @param params  serialization parameters, forwarded unchanged to {@code innerToXContent}
 * @return the {@code builder} that was passed in
 * @throws IOException propagated from the builder calls or from {@code innerToXContent}
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(type);
innerToXContent(builder,params);
builder.endObject();
return builder;
}
/**
 * Writes the fields specific to a concrete smoothing model. {@code toXContent} calls this between
 * opening and closing the object named after the model type, so implementations emit only their own
 * fields and must not open or close that object themselves.
 *
 * @param builder the builder to write the fields to
 * @param params  serialization parameters as received by {@code toXContent}
 * @return a builder; {@code toXContent} ignores this value
 * @throws IOException declared so implementations can propagate builder failures
 */
protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
}
/**
@ -299,8 +302,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder = super.toXContent(builder, params);
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("trigram_lambda", trigramLambda);
builder.field("bigram_lambda", bigramLambda);
builder.field("unigram_lambda", unigramLambda);