Add serialization of smoothing model to PhraseSuggestionBuilder and add tests

This commit is contained in:
Christoph Büscher 2016-01-29 11:50:49 +01:00
parent 7cae28f96a
commit a9ba1e73e7
8 changed files with 81 additions and 10 deletions

View File

@ -39,6 +39,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder; import org.elasticsearch.search.rescore.RescoreBuilder;
import org.elasticsearch.search.suggest.SuggestionBuilder; import org.elasticsearch.search.suggest.SuggestionBuilder;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
@ -706,6 +707,13 @@ public abstract class StreamInput extends InputStream {
return readNamedWriteable(ScoreFunctionBuilder.class); return readNamedWriteable(ScoreFunctionBuilder.class);
} }
/**
 * Reads a {@link SmoothingModel} from the current stream.
 * <p>
 * Delegates to {@code readNamedWriteable}, passing {@link SmoothingModel} as the
 * category class.
 *
 * @return the {@link SmoothingModel} obtained from {@code readNamedWriteable}
 * @throws IOException if the underlying read fails
 */
public SmoothingModel readSmoothingModel() throws IOException {
    final SmoothingModel smoothingModel = readNamedWriteable(SmoothingModel.class);
    return smoothingModel;
}
/** /**
* Reads a list of objects * Reads a list of objects
*/ */

View File

@ -38,6 +38,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.search.rescore.RescoreBuilder; import org.elasticsearch.search.rescore.RescoreBuilder;
import org.elasticsearch.search.suggest.SuggestionBuilder; import org.elasticsearch.search.suggest.SuggestionBuilder;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
import org.joda.time.ReadableInstant; import org.joda.time.ReadableInstant;
import java.io.EOFException; import java.io.EOFException;
@ -670,6 +671,13 @@ public abstract class StreamOutput extends OutputStream {
writeNamedWriteable(scoreFunctionBuilder); writeNamedWriteable(scoreFunctionBuilder);
} }
/**
 * Writes the given {@link SmoothingModel} to the stream.
 * <p>
 * Delegates to {@code writeNamedWriteable}; no null check is performed here, so
 * callers that may hold a {@code null} model are expected to guard the call.
 *
 * @param smoothingModel the model to write
 * @throws IOException if the underlying write fails
 */
public void writeSmoothingModel(SmoothingModel smoothingModel) throws IOException {
    writeNamedWriteable(smoothingModel);
}
/** /**
* Writes the given {@link GeoPoint} to the stream * Writes the given {@link GeoPoint} to the stream
*/ */

View File

@ -216,6 +216,13 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
return this; return this;
} }
/**
 * Returns the {@link SmoothingModel} currently held by this builder.
 *
 * @return the value of the {@code model} field; may be {@code null}
 *         (serialization in this class checks {@code model != null} before writing it)
 */
public SmoothingModel smoothingModel() {
    return model;
}
public PhraseSuggestionBuilder tokenLimit(int tokenLimit) { public PhraseSuggestionBuilder tokenLimit(int tokenLimit) {
this.tokenLimit = tokenLimit; this.tokenLimit = tokenLimit;
return this; return this;
@ -391,8 +398,8 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
* Default discount parameter for {@link StupidBackoff} smoothing * Default discount parameter for {@link StupidBackoff} smoothing
*/ */
public static final double DEFAULT_BACKOFF_DISCOUNT = 0.4; public static final double DEFAULT_BACKOFF_DISCOUNT = 0.4;
private double discount = DEFAULT_BACKOFF_DISCOUNT;
static final StupidBackoff PROTOTYPE = new StupidBackoff(DEFAULT_BACKOFF_DISCOUNT); static final StupidBackoff PROTOTYPE = new StupidBackoff(DEFAULT_BACKOFF_DISCOUNT);
private double discount = DEFAULT_BACKOFF_DISCOUNT;
private static final String NAME = "stupid_backoff"; private static final String NAME = "stupid_backoff";
private static final ParseField DISCOUNT_FIELD = new ParseField("discount"); private static final ParseField DISCOUNT_FIELD = new ParseField("discount");
@ -743,7 +750,11 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
out.writeOptionalFloat(realWordErrorLikelihood); out.writeOptionalFloat(realWordErrorLikelihood);
out.writeOptionalFloat(confidence); out.writeOptionalFloat(confidence);
out.writeOptionalVInt(gramSize); out.writeOptionalVInt(gramSize);
// NORELEASE model.writeTo(); boolean hasModel = model != null;
out.writeBoolean(hasModel);
if (hasModel) {
out.writeSmoothingModel(model);
}
out.writeOptionalBoolean(forceUnigrams); out.writeOptionalBoolean(forceUnigrams);
out.writeOptionalVInt(tokenLimit); out.writeOptionalVInt(tokenLimit);
out.writeOptionalString(preTag); out.writeOptionalString(preTag);
@ -767,7 +778,9 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
builder.realWordErrorLikelihood = in.readOptionalFloat(); builder.realWordErrorLikelihood = in.readOptionalFloat();
builder.confidence = in.readOptionalFloat(); builder.confidence = in.readOptionalFloat();
builder.gramSize = in.readOptionalVInt(); builder.gramSize = in.readOptionalVInt();
// NORELEASE read model if (in.readBoolean()) {
builder.model = in.readSmoothingModel();
}
builder.forceUnigrams = in.readOptionalBoolean(); builder.forceUnigrams = in.readOptionalBoolean();
builder.tokenLimit = in.readOptionalVInt(); builder.tokenLimit = in.readOptionalVInt();
builder.preTag = in.readOptionalString(); builder.preTag = in.readOptionalString();
@ -790,7 +803,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
Objects.equals(confidence, other.confidence) && Objects.equals(confidence, other.confidence) &&
// NORELEASE Objects.equals(generator, other.generator) && // NORELEASE Objects.equals(generator, other.generator) &&
Objects.equals(gramSize, other.gramSize) && Objects.equals(gramSize, other.gramSize) &&
// NORELEASE Objects.equals(model, other.model) && Objects.equals(model, other.model) &&
Objects.equals(forceUnigrams, other.forceUnigrams) && Objects.equals(forceUnigrams, other.forceUnigrams) &&
Objects.equals(tokenLimit, other.tokenLimit) && Objects.equals(tokenLimit, other.tokenLimit) &&
Objects.equals(preTag, other.preTag) && Objects.equals(preTag, other.preTag) &&
@ -803,10 +816,8 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
@Override @Override
protected int doHashCode() { protected int doHashCode() {
return Objects.hash(maxErrors, separator, realWordErrorLikelihood, confidence, return Objects.hash(maxErrors, separator, realWordErrorLikelihood, confidence,
/** NORELEASE generators, */ // NORELEASE generators,
gramSize, gramSize, model, forceUnigrams, tokenLimit, preTag, postTag,
/** NORELEASE model, */
forceUnigrams, tokenLimit, preTag, postTag,
collateQuery, collateParams, collatePrune); collateQuery, collateParams, collatePrune);
} }

View File

@ -38,7 +38,7 @@ import static org.hamcrest.Matchers.not;
public abstract class AbstractSuggestionBuilderTestCase<SB extends SuggestionBuilder<SB>> extends ESTestCase { public abstract class AbstractSuggestionBuilderTestCase<SB extends SuggestionBuilder<SB>> extends ESTestCase {
private static final int NUMBER_OF_TESTBUILDERS = 20; private static final int NUMBER_OF_TESTBUILDERS = 20;
private static NamedWriteableRegistry namedWriteableRegistry; protected static NamedWriteableRegistry namedWriteableRegistry;
/** /**
* setup for the whole base test class * setup for the whole base test class

View File

@ -28,6 +28,11 @@ public class LaplaceModelTests extends SmoothingModelTestCase {
@Override @Override
protected SmoothingModel createTestModel() { protected SmoothingModel createTestModel() {
return createRandomModel();
}
static SmoothingModel createRandomModel() {
return new Laplace(randomDoubleBetween(0.0, 10.0, false)); return new Laplace(randomDoubleBetween(0.0, 10.0, false));
} }

View File

@ -28,6 +28,10 @@ public class LinearInterpolationModelTests extends SmoothingModelTestCase {
@Override @Override
protected SmoothingModel createTestModel() { protected SmoothingModel createTestModel() {
return createRandomModel();
}
static LinearInterpolation createRandomModel() {
double trigramLambda = randomDoubleBetween(0.0, 10.0, false); double trigramLambda = randomDoubleBetween(0.0, 10.0, false);
double bigramLambda = randomDoubleBetween(0.0, 10.0, false); double bigramLambda = randomDoubleBetween(0.0, 10.0, false);
double unigramLambda = randomDoubleBetween(0.0, 10.0, false); double unigramLambda = randomDoubleBetween(0.0, 10.0, false);

View File

@ -21,6 +21,11 @@ package org.elasticsearch.search.suggest.phrase;
import org.elasticsearch.script.Template; import org.elasticsearch.script.Template;
import org.elasticsearch.search.suggest.AbstractSuggestionBuilderTestCase; import org.elasticsearch.search.suggest.AbstractSuggestionBuilderTestCase;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff;
import org.junit.BeforeClass;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
@ -28,6 +33,13 @@ import java.util.Map;
public class PhraseSuggestionBuilderTests extends AbstractSuggestionBuilderTestCase<PhraseSuggestionBuilder> { public class PhraseSuggestionBuilderTests extends AbstractSuggestionBuilderTestCase<PhraseSuggestionBuilder> {
/**
 * Registers the prototypes of the three smoothing model implementations
 * ({@link Laplace}, {@link LinearInterpolation}, {@link StupidBackoff}) in the
 * inherited {@code namedWriteableRegistry} under the {@link SmoothingModel} category.
 * <p>
 * NOTE(review): presumably required so that serialized builders containing a model
 * can be read back in the serialization tests; this also assumes the base class has
 * already created the registry when this method runs - confirm.
 */
@BeforeClass
public static void initSmoothingModels() {
namedWriteableRegistry.registerPrototype(SmoothingModel.class, Laplace.PROTOTYPE);
namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE);
namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE);
}
@Override @Override
protected PhraseSuggestionBuilder randomSuggestionBuilder() { protected PhraseSuggestionBuilder randomSuggestionBuilder() {
PhraseSuggestionBuilder testBuilder = new PhraseSuggestionBuilder(randomAsciiOfLength(10)); PhraseSuggestionBuilder testBuilder = new PhraseSuggestionBuilder(randomAsciiOfLength(10));
@ -50,7 +62,7 @@ public class PhraseSuggestionBuilderTests extends AbstractSuggestionBuilderTestC
testBuilder.collateParams(collateParams ); testBuilder.collateParams(collateParams );
} }
if (randomBoolean()) { if (randomBoolean()) {
// NORELEASE add random model randomSmoothingModel();
} }
if (randomBoolean()) { if (randomBoolean()) {
@ -59,6 +71,22 @@ public class PhraseSuggestionBuilderTests extends AbstractSuggestionBuilderTestC
return testBuilder; return testBuilder;
} }
/**
 * Picks one of the three smoothing model test factories at random and returns
 * the model it creates.
 *
 * @return a random model from {@code LaplaceModelTests}, {@code StupidBackoffModelTests}
 *         or {@code LinearInterpolationModelTests}; {@code null} only if the random
 *         choice falls outside 0..2
 */
private static SmoothingModel randomSmoothingModel() {
    final int choice = randomIntBetween(0, 2);
    if (choice == 0) {
        return LaplaceModelTests.createRandomModel();
    }
    if (choice == 1) {
        return StupidBackoffModelTests.createRandomModel();
    }
    if (choice == 2) {
        return LinearInterpolationModelTests.createRandomModel();
    }
    // Mirrors the original's null default when no case matches.
    return null;
}
@Override @Override
protected void mutateSpecificParameters(PhraseSuggestionBuilder builder) throws IOException { protected void mutateSpecificParameters(PhraseSuggestionBuilder builder) throws IOException {
switch (randomIntBetween(0, 7)) { switch (randomIntBetween(0, 7)) {
@ -107,6 +135,9 @@ public class PhraseSuggestionBuilderTests extends AbstractSuggestionBuilderTestC
case 10: case 10:
builder.collateParams().put(randomAsciiOfLength(5), randomAsciiOfLength(5)); builder.collateParams().put(randomAsciiOfLength(5), randomAsciiOfLength(5));
break; break;
case 11:
builder.smoothingModel(randomValueOtherThan(builder.smoothingModel(), PhraseSuggestionBuilderTests::randomSmoothingModel));
break;
// TODO mutate random Model && generator // TODO mutate random Model && generator
} }
} }

View File

@ -28,6 +28,10 @@ public class StupidBackoffModelTests extends SmoothingModelTestCase {
@Override @Override
protected SmoothingModel createTestModel() { protected SmoothingModel createTestModel() {
return createRandomModel();
}
static SmoothingModel createRandomModel() {
return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false)); return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false));
} }