* [ML] Add bwc serialization unit test scaffold (#51889) Adds new `AbstractBWCSerializationTestCase` which provides easy scaffolding for BWC serialization unit tests. These are no replacement for true BWC tests (which execute actual old code). These tests do provide some good coverage for the current code when serializing to/from old versions. * removing unnecessary override for 7.series branch * adding necessary import Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
c6111eb90e
commit
dffcd021df
|
@ -157,6 +157,11 @@ public class DataFrameAnalyticsSource implements Writeable, ToXContentObject {
|
|||
return queryProvider.getParsingException();
|
||||
}
|
||||
|
||||
// visible for testing
|
||||
QueryProvider getQueryProvider() {
|
||||
return queryProvider;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the parser and returns any gathered deprecations
|
||||
*
|
||||
|
|
|
@ -190,6 +190,15 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
|
||||
private Builder() {}
|
||||
|
||||
Builder(BoostedTreeParams params) {
|
||||
this.lambda = params.lambda;
|
||||
this.gamma = params.gamma;
|
||||
this.eta = params.eta;
|
||||
this.maximumNumberTrees = params.maximumNumberTrees;
|
||||
this.featureBagFraction = params.featureBagFraction;
|
||||
this.numTopFeatureImportanceValues = params.numTopFeatureImportanceValues;
|
||||
}
|
||||
|
||||
public Builder setLambda(Double lambda) {
|
||||
this.lambda = lambda;
|
||||
return this;
|
||||
|
|
|
@ -152,8 +152,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Long getRandomizeSeed() {
|
||||
public long getRandomizeSeed() {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
|
|
|
@ -271,6 +271,17 @@ public class OutlierDetection implements DataFrameAnalysis {
|
|||
private double outlierFraction = 0.05;
|
||||
private boolean standardizationEnabled = true;
|
||||
|
||||
public Builder() {}
|
||||
|
||||
public Builder(OutlierDetection other) {
|
||||
this.nNeighbors = other.nNeighbors;
|
||||
this.method = other.method;
|
||||
this.featureInfluenceThreshold = other.featureInfluenceThreshold;
|
||||
this.computeFeatureInfluence = other.computeFeatureInfluence;
|
||||
this.outlierFraction = other.outlierFraction;
|
||||
this.standardizationEnabled = other.standardizationEnabled;
|
||||
}
|
||||
|
||||
public Builder setNNeighbors(Integer nNeighbors) {
|
||||
this.nNeighbors = nNeighbors;
|
||||
return this;
|
||||
|
|
|
@ -116,8 +116,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Long getRandomizeSeed() {
|
||||
public long getRandomizeSeed() {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
|
@ -222,7 +221,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == randomizeSeed;
|
||||
&& randomizeSeed == that.randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.Version.getDeclaredVersions;
|
||||
|
||||
public abstract class AbstractBWCSerializationTestCase<T extends Writeable & ToXContent> extends AbstractSerializingTestCase<T> {
|
||||
|
||||
private static final List<Version> ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class));
|
||||
|
||||
public static List<Version> getAllBWCVersions(Version version) {
|
||||
return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static final List<Version> DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT);
|
||||
|
||||
/**
|
||||
* Returns the expected instance if serialized from the given version.
|
||||
*/
|
||||
protected abstract T mutateInstanceForVersion(T instance, Version version);
|
||||
|
||||
/**
|
||||
* The bwc versions to test serialization against
|
||||
*/
|
||||
protected List<Version> bwcVersions() {
|
||||
return DEFAULT_BWC_VERSIONS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test serialization and deserialization of the test instance across versions
|
||||
*/
|
||||
public final void testBwcSerialization() throws IOException {
|
||||
for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) {
|
||||
T testInstance = createTestInstance();
|
||||
for (Version bwcVersion : bwcVersions()) {
|
||||
assertBwcSerialization(testInstance, bwcVersion);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Assert that instances copied at a particular version are equal. The version is useful
|
||||
* for sanity checking the backwards compatibility of the wire. It isn't a substitute for
|
||||
* real backwards compatibility tests but it is *so* much faster.
|
||||
*/
|
||||
protected final void assertBwcSerialization(T testInstance, Version version) throws IOException {
|
||||
T deserializedInstance = copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), version);
|
||||
assertOnBWCObject(deserializedInstance, mutateInstanceForVersion(testInstance, version), version);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param bwcSerializedObject The object deserialized from the previous version
|
||||
* @param testInstance The original test instance
|
||||
* @param version The version which serialized
|
||||
*/
|
||||
protected void assertOnBWCObject(T bwcSerializedObject, T testInstance, Version version) {
|
||||
assertNotSame(version.toString(), bwcSerializedObject, testInstance);
|
||||
assertEquals(version.toString(), bwcSerializedObject, testInstance);
|
||||
assertEquals(version.toString(), bwcSerializedObject.hashCode(), testInstance.hashCode());
|
||||
}
|
||||
}
|
|
@ -31,10 +31,15 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
|
|||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.ClassificationTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -56,7 +61,7 @@ import static org.hamcrest.Matchers.notNullValue;
|
|||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
|
||||
public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<DataFrameAnalyticsConfig> {
|
||||
public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestCase<DataFrameAnalyticsConfig> {
|
||||
|
||||
@Override
|
||||
protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
|
@ -88,6 +93,83 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
|
|||
return createRandom(randomValidId(), lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected DataFrameAnalyticsConfig mutateInstanceForVersion(DataFrameAnalyticsConfig instance, Version version) {
|
||||
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(instance)
|
||||
.setSource(DataFrameAnalyticsSourceTests.mutateForVersion(instance.getSource(), version))
|
||||
.setDest(DataFrameAnalyticsDestTests.mutateForVersion(instance.getDest(), version));
|
||||
if (instance.getAnalysis() instanceof OutlierDetection) {
|
||||
builder.setAnalysis(OutlierDetectionTests.mutateForVersion((OutlierDetection)instance.getAnalysis(), version));
|
||||
}
|
||||
if (instance.getAnalysis() instanceof Regression) {
|
||||
builder.setAnalysis(RegressionTests.mutateForVersion((Regression)instance.getAnalysis(), version));
|
||||
}
|
||||
if (instance.getAnalysis() instanceof Classification) {
|
||||
builder.setAnalysis(ClassificationTests.mutateForVersion((Classification)instance.getAnalysis(), version));
|
||||
}
|
||||
if (version.before(Version.V_7_5_0)) {
|
||||
builder.setAllowLazyStart(false);
|
||||
}
|
||||
if (version.before(Version.V_7_4_0)) {
|
||||
builder.setDescription(null);
|
||||
}
|
||||
if (version.before(Version.V_7_3_0)) {
|
||||
builder.setCreateTime(null);
|
||||
builder.setVersion(null);
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void assertOnBWCObject(DataFrameAnalyticsConfig bwcSerializedObject, DataFrameAnalyticsConfig testInstance, Version version) {
|
||||
|
||||
// Don't have to worry about Regression/Classifications Seeds
|
||||
if (version.onOrAfter(Version.V_7_6_0) || testInstance.getAnalysis() instanceof OutlierDetection) {
|
||||
super.assertOnBWCObject(bwcSerializedObject, testInstance, version);
|
||||
return;
|
||||
}
|
||||
DataFrameAnalysis bwcAnalysis;
|
||||
DataFrameAnalysis testAnalysis;
|
||||
if (testInstance.getAnalysis() instanceof Regression) {
|
||||
Regression testRegression = (Regression)testInstance.getAnalysis();
|
||||
Regression bwcRegression = (Regression)bwcSerializedObject.getAnalysis();
|
||||
|
||||
bwcAnalysis = new Regression(bwcRegression.getDependentVariable(),
|
||||
bwcRegression.getBoostedTreeParams(),
|
||||
bwcRegression.getPredictionFieldName(),
|
||||
bwcRegression.getTrainingPercent(),
|
||||
42L);
|
||||
testAnalysis = new Regression(testRegression.getDependentVariable(),
|
||||
testRegression.getBoostedTreeParams(),
|
||||
testRegression.getPredictionFieldName(),
|
||||
testRegression.getTrainingPercent(),
|
||||
42L);
|
||||
} else {
|
||||
Classification testClassification = (Classification)testInstance.getAnalysis();
|
||||
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
|
||||
bwcAnalysis = new Classification(bwcClassification.getDependentVariable(),
|
||||
bwcClassification.getBoostedTreeParams(),
|
||||
bwcClassification.getPredictionFieldName(),
|
||||
bwcClassification.getNumTopClasses(),
|
||||
bwcClassification.getTrainingPercent(),
|
||||
42L);
|
||||
testAnalysis = new Classification(testClassification.getDependentVariable(),
|
||||
testClassification.getBoostedTreeParams(),
|
||||
testClassification.getPredictionFieldName(),
|
||||
testClassification.getNumTopClasses(),
|
||||
testClassification.getTrainingPercent(),
|
||||
42L);
|
||||
}
|
||||
super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
|
||||
.setAnalysis(bwcAnalysis)
|
||||
.build(),
|
||||
new DataFrameAnalyticsConfig.Builder(testInstance)
|
||||
.setAnalysis(testAnalysis)
|
||||
.build(),
|
||||
version);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<DataFrameAnalyticsConfig> instanceReader() {
|
||||
return DataFrameAnalyticsConfig::new;
|
||||
|
@ -98,19 +180,23 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
|
|||
}
|
||||
|
||||
public static DataFrameAnalyticsConfig createRandom(String id, boolean withGeneratedFields) {
|
||||
return createRandomBuilder(id, withGeneratedFields).build();
|
||||
return createRandomBuilder(id, withGeneratedFields, randomFrom(OutlierDetectionTests.createRandom(),
|
||||
RegressionTests.createRandom(),
|
||||
ClassificationTests.createRandom())).build();
|
||||
}
|
||||
|
||||
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) {
|
||||
return createRandomBuilder(id, false);
|
||||
return createRandomBuilder(id, false, randomFrom(OutlierDetectionTests.createRandom(),
|
||||
RegressionTests.createRandom(),
|
||||
ClassificationTests.createRandom()));
|
||||
}
|
||||
|
||||
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id, boolean withGeneratedFields) {
|
||||
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id, boolean withGeneratedFields, DataFrameAnalysis analysis) {
|
||||
DataFrameAnalyticsSource source = DataFrameAnalyticsSourceTests.createRandom();
|
||||
DataFrameAnalyticsDest dest = DataFrameAnalyticsDestTests.createRandom();
|
||||
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder()
|
||||
.setId(id)
|
||||
.setAnalysis(OutlierDetectionTests.createRandom())
|
||||
.setAnalysis(analysis)
|
||||
.setSource(source)
|
||||
.setDest(dest);
|
||||
if (randomBoolean()) {
|
||||
|
|
|
@ -5,13 +5,14 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase<DataFrameAnalyticsDest> {
|
||||
public class DataFrameAnalyticsDestTests extends AbstractBWCSerializationTestCase<DataFrameAnalyticsDest> {
|
||||
|
||||
@Override
|
||||
protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException {
|
||||
|
@ -29,8 +30,17 @@ public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase<Dat
|
|||
return new DataFrameAnalyticsDest(index, resultsField);
|
||||
}
|
||||
|
||||
public static DataFrameAnalyticsDest mutateForVersion(DataFrameAnalyticsDest instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<DataFrameAnalyticsDest> instanceReader() {
|
||||
return DataFrameAnalyticsDest::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected DataFrameAnalyticsDest mutateInstanceForVersion(DataFrameAnalyticsDest instance, Version version) {
|
||||
return mutateForVersion(instance, version);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
|
@ -13,7 +14,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -25,7 +26,7 @@ import java.util.List;
|
|||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase<DataFrameAnalyticsSource> {
|
||||
public class DataFrameAnalyticsSourceTests extends AbstractBWCSerializationTestCase<DataFrameAnalyticsSource> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
|
@ -69,6 +70,13 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase<D
|
|||
return new DataFrameAnalyticsSource(index, queryProvider, sourceFiltering);
|
||||
}
|
||||
|
||||
public static DataFrameAnalyticsSource mutateForVersion(DataFrameAnalyticsSource instance, Version version) {
|
||||
if (version.before(Version.V_7_6_0)) {
|
||||
return new DataFrameAnalyticsSource(instance.getIndex(), instance.getQueryProvider(), null);
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<DataFrameAnalyticsSource> instanceReader() {
|
||||
return DataFrameAnalyticsSource::new;
|
||||
|
@ -132,4 +140,9 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase<D
|
|||
includes.toArray(new String[0]), excludes.toArray(new String[0]));
|
||||
return new DataFrameAnalyticsSource(new String[] { "index" } , null, sourceFiltering);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected DataFrameAnalyticsSource mutateInstanceForVersion(DataFrameAnalyticsSource instance, Version version) {
|
||||
return mutateForVersion(instance, version);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,16 +6,17 @@
|
|||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedTreeParams> {
|
||||
public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<BoostedTreeParams> {
|
||||
|
||||
@Override
|
||||
protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOException {
|
||||
|
@ -44,6 +45,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
|||
.build();
|
||||
}
|
||||
|
||||
public static BoostedTreeParams mutateForVersion(BoostedTreeParams instance, Version version) {
|
||||
BoostedTreeParams.Builder builder = new BoostedTreeParams.Builder(instance);
|
||||
if (version.before(Version.V_7_6_0)) {
|
||||
builder.setNumTopFeatureImportanceValues(null);
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<BoostedTreeParams> instanceReader() {
|
||||
return BoostedTreeParams::new;
|
||||
|
@ -111,4 +120,9 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
|||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BoostedTreeParams mutateInstanceForVersion(BoostedTreeParams instance, Version version) {
|
||||
return mutateForVersion(instance, version);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
|
|||
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.hamcrest.Matchers;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -37,7 +37,7 @@ import static org.hamcrest.Matchers.not;
|
|||
import static org.hamcrest.Matchers.notNullValue;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
public class ClassificationTests extends AbstractBWCSerializationTestCase<Classification> {
|
||||
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
|
||||
|
||||
|
@ -62,6 +62,37 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
randomizeSeed);
|
||||
}
|
||||
|
||||
public static Classification mutateForVersion(Classification instance, Version version) {
|
||||
return new Classification(instance.getDependentVariable(),
|
||||
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
|
||||
instance.getPredictionFieldName(),
|
||||
instance.getNumTopClasses(),
|
||||
instance.getTrainingPercent(),
|
||||
instance.getRandomizeSeed());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void assertOnBWCObject(Classification bwcSerializedObject, Classification testInstance, Version version) {
|
||||
if (version.onOrAfter(Version.V_7_6_0)) {
|
||||
super.assertOnBWCObject(bwcSerializedObject, testInstance, version);
|
||||
return;
|
||||
}
|
||||
|
||||
Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
|
||||
bwcSerializedObject.getBoostedTreeParams(),
|
||||
bwcSerializedObject.getPredictionFieldName(),
|
||||
bwcSerializedObject.getNumTopClasses(),
|
||||
bwcSerializedObject.getTrainingPercent(),
|
||||
42L);
|
||||
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
testInstance.getNumTopClasses(),
|
||||
testInstance.getTrainingPercent(),
|
||||
42L);
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Classification> instanceReader() {
|
||||
return Classification::new;
|
||||
|
@ -270,4 +301,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
assertThat(Classification.extractJobIdFromStateDoc("foo_bar-1_classification_state#1"), equalTo("foo_bar-1"));
|
||||
assertThat(Classification.extractJobIdFromStateDoc("noop"), is(nullValue()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification mutateInstanceForVersion(Classification instance, Version version) {
|
||||
return mutateForVersion(instance, version);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,9 +5,10 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
|
@ -18,7 +19,7 @@ import static org.hamcrest.Matchers.empty;
|
|||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
|
||||
public class OutlierDetectionTests extends AbstractBWCSerializationTestCase<OutlierDetection> {
|
||||
|
||||
@Override
|
||||
protected OutlierDetection doParseInstance(XContentParser parser) throws IOException {
|
||||
|
@ -44,6 +45,17 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
|
|||
.build();
|
||||
}
|
||||
|
||||
public static OutlierDetection mutateForVersion(OutlierDetection instance, Version version) {
|
||||
if (version.before(Version.V_7_5_0)) {
|
||||
return new OutlierDetection.Builder(instance)
|
||||
.setComputeFeatureInfluence(true)
|
||||
.setOutlierFraction(0.05)
|
||||
.setStandardizationEnabled(true)
|
||||
.build();
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<OutlierDetection> instanceReader() {
|
||||
return OutlierDetection::new;
|
||||
|
@ -101,4 +113,9 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
|
|||
assertThat(outlierDetection.persistsState(), is(false));
|
||||
expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OutlierDetection mutateInstanceForVersion(OutlierDetection instance, Version version) {
|
||||
return mutateForVersion(instance, version);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
@ -28,7 +28,7 @@ import static org.hamcrest.Matchers.not;
|
|||
import static org.hamcrest.Matchers.notNullValue;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
public class RegressionTests extends AbstractBWCSerializationTestCase<Regression> {
|
||||
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
|
||||
|
||||
|
@ -42,7 +42,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
return createRandom();
|
||||
}
|
||||
|
||||
private static Regression createRandom() {
|
||||
public static Regression createRandom() {
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
|
@ -51,6 +51,39 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
||||
public static Regression mutateForVersion(Regression instance, Version version) {
|
||||
return new Regression(instance.getDependentVariable(),
|
||||
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
|
||||
instance.getPredictionFieldName(),
|
||||
instance.getTrainingPercent(),
|
||||
instance.getRandomizeSeed());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void assertOnBWCObject(Regression bwcSerializedObject, Regression testInstance, Version version) {
|
||||
if (version.onOrAfter(Version.V_7_6_0)) {
|
||||
super.assertOnBWCObject(bwcSerializedObject, testInstance, version);
|
||||
return;
|
||||
}
|
||||
|
||||
Regression newBwc = new Regression(bwcSerializedObject.getDependentVariable(),
|
||||
bwcSerializedObject.getBoostedTreeParams(),
|
||||
bwcSerializedObject.getPredictionFieldName(),
|
||||
bwcSerializedObject.getTrainingPercent(),
|
||||
42L);
|
||||
Regression newInstance = new Regression(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
testInstance.getTrainingPercent(),
|
||||
42L);
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression mutateInstanceForVersion(Regression instance, Version version) {
|
||||
return mutateForVersion(instance, version);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Regression> instanceReader() {
|
||||
return Regression::new;
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||
|
@ -87,7 +88,9 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
|
|||
task = mock(DataFrameAnalyticsTask.class);
|
||||
when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
|
||||
when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class));
|
||||
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
|
||||
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
|
||||
false,
|
||||
OutlierDetectionTests.createRandom()).build();
|
||||
dataExtractor = mock(DataFrameDataExtractor.class);
|
||||
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
|
||||
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
|
||||
|
|
Loading…
Reference in New Issue