[7.x] [ML] Add bwc serialization unit test scaffold (#51889) (#52061)

* [ML] Add bwc serialization unit test scaffold (#51889)

Adds new `AbstractBWCSerializationTestCase` which provides easy scaffolding for BWC serialization unit tests.

These are no replacement for true BWC tests (which execute actual old code). These tests do provide some good coverage for the current code when serializing to/from old versions.

* removing unnecessary override for 7.series branch

* adding necessary import

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Benjamin Trent 2020-02-07 17:17:11 -05:00 committed by GitHub
parent c6111eb90e
commit dffcd021df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 334 additions and 25 deletions

View File

@ -157,6 +157,11 @@ public class DataFrameAnalyticsSource implements Writeable, ToXContentObject {
return queryProvider.getParsingException();
}
// visible for testing
QueryProvider getQueryProvider() {
return queryProvider;
}
/**
* Calls the parser and returns any gathered deprecations
*

View File

@ -190,6 +190,15 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
private Builder() {}
Builder(BoostedTreeParams params) {
this.lambda = params.lambda;
this.gamma = params.gamma;
this.eta = params.eta;
this.maximumNumberTrees = params.maximumNumberTrees;
this.featureBagFraction = params.featureBagFraction;
this.numTopFeatureImportanceValues = params.numTopFeatureImportanceValues;
}
public Builder setLambda(Double lambda) {
this.lambda = lambda;
return this;

View File

@ -152,8 +152,7 @@ public class Classification implements DataFrameAnalysis {
return trainingPercent;
}
@Nullable
public Long getRandomizeSeed() {
public long getRandomizeSeed() {
return randomizeSeed;
}

View File

@ -271,6 +271,17 @@ public class OutlierDetection implements DataFrameAnalysis {
private double outlierFraction = 0.05;
private boolean standardizationEnabled = true;
public Builder() {}
public Builder(OutlierDetection other) {
this.nNeighbors = other.nNeighbors;
this.method = other.method;
this.featureInfluenceThreshold = other.featureInfluenceThreshold;
this.computeFeatureInfluence = other.computeFeatureInfluence;
this.outlierFraction = other.outlierFraction;
this.standardizationEnabled = other.standardizationEnabled;
}
public Builder setNNeighbors(Integer nNeighbors) {
this.nNeighbors = nNeighbors;
return this;

View File

@ -116,8 +116,7 @@ public class Regression implements DataFrameAnalysis {
return trainingPercent;
}
@Nullable
public Long getRandomizeSeed() {
public long getRandomizeSeed() {
return randomizeSeed;
}
@ -222,7 +221,7 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& trainingPercent == that.trainingPercent
&& randomizeSeed == randomizeSeed;
&& randomizeSeed == that.randomizeSeed;
}
@Override

View File

@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.elasticsearch.Version.getDeclaredVersions;
public abstract class AbstractBWCSerializationTestCase<T extends Writeable & ToXContent> extends AbstractSerializingTestCase<T> {
private static final List<Version> ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class));
public static List<Version> getAllBWCVersions(Version version) {
return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList());
}
private static final List<Version> DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT);
/**
* Returns the expected instance if serialized from the given version.
*/
protected abstract T mutateInstanceForVersion(T instance, Version version);
/**
* The bwc versions to test serialization against
*/
protected List<Version> bwcVersions() {
return DEFAULT_BWC_VERSIONS;
}
/**
* Test serialization and deserialization of the test instance across versions
*/
public final void testBwcSerialization() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) {
T testInstance = createTestInstance();
for (Version bwcVersion : bwcVersions()) {
assertBwcSerialization(testInstance, bwcVersion);
}
}
}
/**
* Assert that instances copied at a particular version are equal. The version is useful
* for sanity checking the backwards compatibility of the wire. It isn't a substitute for
* real backwards compatibility tests but it is *so* much faster.
*/
protected final void assertBwcSerialization(T testInstance, Version version) throws IOException {
T deserializedInstance = copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), version);
assertOnBWCObject(deserializedInstance, mutateInstanceForVersion(testInstance, version), version);
}
/**
* @param bwcSerializedObject The object deserialized from the previous version
* @param testInstance The original test instance
* @param version The version which serialized
*/
protected void assertOnBWCObject(T bwcSerializedObject, T testInstance, Version version) {
assertNotSame(version.toString(), bwcSerializedObject, testInstance);
assertEquals(version.toString(), bwcSerializedObject, testInstance);
assertEquals(version.toString(), bwcSerializedObject.hashCode(), testInstance.hashCode());
}
}

View File

@ -31,10 +31,15 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.ClassificationTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before;
@ -56,7 +61,7 @@ import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<DataFrameAnalyticsConfig> {
public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestCase<DataFrameAnalyticsConfig> {
@Override
protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException {
@ -88,6 +93,83 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
return createRandom(randomValidId(), lenient);
}
@Override
protected DataFrameAnalyticsConfig mutateInstanceForVersion(DataFrameAnalyticsConfig instance, Version version) {
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(instance)
.setSource(DataFrameAnalyticsSourceTests.mutateForVersion(instance.getSource(), version))
.setDest(DataFrameAnalyticsDestTests.mutateForVersion(instance.getDest(), version));
if (instance.getAnalysis() instanceof OutlierDetection) {
builder.setAnalysis(OutlierDetectionTests.mutateForVersion((OutlierDetection)instance.getAnalysis(), version));
}
if (instance.getAnalysis() instanceof Regression) {
builder.setAnalysis(RegressionTests.mutateForVersion((Regression)instance.getAnalysis(), version));
}
if (instance.getAnalysis() instanceof Classification) {
builder.setAnalysis(ClassificationTests.mutateForVersion((Classification)instance.getAnalysis(), version));
}
if (version.before(Version.V_7_5_0)) {
builder.setAllowLazyStart(false);
}
if (version.before(Version.V_7_4_0)) {
builder.setDescription(null);
}
if (version.before(Version.V_7_3_0)) {
builder.setCreateTime(null);
builder.setVersion(null);
}
return builder.build();
}
@Override
protected void assertOnBWCObject(DataFrameAnalyticsConfig bwcSerializedObject, DataFrameAnalyticsConfig testInstance, Version version) {
// Don't have to worry about Regression/Classifications Seeds
if (version.onOrAfter(Version.V_7_6_0) || testInstance.getAnalysis() instanceof OutlierDetection) {
super.assertOnBWCObject(bwcSerializedObject, testInstance, version);
return;
}
DataFrameAnalysis bwcAnalysis;
DataFrameAnalysis testAnalysis;
if (testInstance.getAnalysis() instanceof Regression) {
Regression testRegression = (Regression)testInstance.getAnalysis();
Regression bwcRegression = (Regression)bwcSerializedObject.getAnalysis();
bwcAnalysis = new Regression(bwcRegression.getDependentVariable(),
bwcRegression.getBoostedTreeParams(),
bwcRegression.getPredictionFieldName(),
bwcRegression.getTrainingPercent(),
42L);
testAnalysis = new Regression(testRegression.getDependentVariable(),
testRegression.getBoostedTreeParams(),
testRegression.getPredictionFieldName(),
testRegression.getTrainingPercent(),
42L);
} else {
Classification testClassification = (Classification)testInstance.getAnalysis();
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
bwcAnalysis = new Classification(bwcClassification.getDependentVariable(),
bwcClassification.getBoostedTreeParams(),
bwcClassification.getPredictionFieldName(),
bwcClassification.getNumTopClasses(),
bwcClassification.getTrainingPercent(),
42L);
testAnalysis = new Classification(testClassification.getDependentVariable(),
testClassification.getBoostedTreeParams(),
testClassification.getPredictionFieldName(),
testClassification.getNumTopClasses(),
testClassification.getTrainingPercent(),
42L);
}
super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
.setAnalysis(bwcAnalysis)
.build(),
new DataFrameAnalyticsConfig.Builder(testInstance)
.setAnalysis(testAnalysis)
.build(),
version);
}
@Override
protected Writeable.Reader<DataFrameAnalyticsConfig> instanceReader() {
return DataFrameAnalyticsConfig::new;
@ -98,19 +180,23 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
}
public static DataFrameAnalyticsConfig createRandom(String id, boolean withGeneratedFields) {
return createRandomBuilder(id, withGeneratedFields).build();
return createRandomBuilder(id, withGeneratedFields, randomFrom(OutlierDetectionTests.createRandom(),
RegressionTests.createRandom(),
ClassificationTests.createRandom())).build();
}
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) {
return createRandomBuilder(id, false);
return createRandomBuilder(id, false, randomFrom(OutlierDetectionTests.createRandom(),
RegressionTests.createRandom(),
ClassificationTests.createRandom()));
}
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id, boolean withGeneratedFields) {
public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id, boolean withGeneratedFields, DataFrameAnalysis analysis) {
DataFrameAnalyticsSource source = DataFrameAnalyticsSourceTests.createRandom();
DataFrameAnalyticsDest dest = DataFrameAnalyticsDestTests.createRandom();
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder()
.setId(id)
.setAnalysis(OutlierDetectionTests.createRandom())
.setAnalysis(analysis)
.setSource(source)
.setDest(dest);
if (randomBoolean()) {

View File

@ -5,13 +5,14 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase<DataFrameAnalyticsDest> {
public class DataFrameAnalyticsDestTests extends AbstractBWCSerializationTestCase<DataFrameAnalyticsDest> {
@Override
protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException {
@ -29,8 +30,17 @@ public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase<Dat
return new DataFrameAnalyticsDest(index, resultsField);
}
public static DataFrameAnalyticsDest mutateForVersion(DataFrameAnalyticsDest instance, Version version) {
return instance;
}
@Override
protected Writeable.Reader<DataFrameAnalyticsDest> instanceReader() {
return DataFrameAnalyticsDest::new;
}
@Override
protected DataFrameAnalyticsDest mutateInstanceForVersion(DataFrameAnalyticsDest instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
@ -13,7 +14,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import java.io.IOException;
@ -25,7 +26,7 @@ import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase<DataFrameAnalyticsSource> {
public class DataFrameAnalyticsSourceTests extends AbstractBWCSerializationTestCase<DataFrameAnalyticsSource> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
@ -69,6 +70,13 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase<D
return new DataFrameAnalyticsSource(index, queryProvider, sourceFiltering);
}
public static DataFrameAnalyticsSource mutateForVersion(DataFrameAnalyticsSource instance, Version version) {
if (version.before(Version.V_7_6_0)) {
return new DataFrameAnalyticsSource(instance.getIndex(), instance.getQueryProvider(), null);
}
return instance;
}
@Override
protected Writeable.Reader<DataFrameAnalyticsSource> instanceReader() {
return DataFrameAnalyticsSource::new;
@ -132,4 +140,9 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase<D
includes.toArray(new String[0]), excludes.toArray(new String[0]));
return new DataFrameAnalyticsSource(new String[] { "index" } , null, sourceFiltering);
}
@Override
protected DataFrameAnalyticsSource mutateInstanceForVersion(DataFrameAnalyticsSource instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -6,16 +6,17 @@
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedTreeParams> {
public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<BoostedTreeParams> {
@Override
protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOException {
@ -44,6 +45,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
.build();
}
public static BoostedTreeParams mutateForVersion(BoostedTreeParams instance, Version version) {
BoostedTreeParams.Builder builder = new BoostedTreeParams.Builder(instance);
if (version.before(Version.V_7_6_0)) {
builder.setNumTopFeatureImportanceValues(null);
}
return builder.build();
}
@Override
protected Writeable.Reader<BoostedTreeParams> instanceReader() {
return BoostedTreeParams::new;
@ -111,4 +120,9 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer"));
}
@Override
protected BoostedTreeParams mutateInstanceForVersion(BoostedTreeParams instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -16,7 +16,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.hamcrest.Matchers;
import java.io.IOException;
@ -37,7 +37,7 @@ import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
public class ClassificationTests extends AbstractBWCSerializationTestCase<Classification> {
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
@ -62,6 +62,37 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
randomizeSeed);
}
public static Classification mutateForVersion(Classification instance, Version version) {
return new Classification(instance.getDependentVariable(),
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
instance.getPredictionFieldName(),
instance.getNumTopClasses(),
instance.getTrainingPercent(),
instance.getRandomizeSeed());
}
@Override
protected void assertOnBWCObject(Classification bwcSerializedObject, Classification testInstance, Version version) {
if (version.onOrAfter(Version.V_7_6_0)) {
super.assertOnBWCObject(bwcSerializedObject, testInstance, version);
return;
}
Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
bwcSerializedObject.getBoostedTreeParams(),
bwcSerializedObject.getPredictionFieldName(),
bwcSerializedObject.getNumTopClasses(),
bwcSerializedObject.getTrainingPercent(),
42L);
Classification newInstance = new Classification(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
testInstance.getNumTopClasses(),
testInstance.getTrainingPercent(),
42L);
super.assertOnBWCObject(newBwc, newInstance, version);
}
@Override
protected Writeable.Reader<Classification> instanceReader() {
return Classification::new;
@ -270,4 +301,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
assertThat(Classification.extractJobIdFromStateDoc("foo_bar-1_classification_state#1"), equalTo("foo_bar-1"));
assertThat(Classification.extractJobIdFromStateDoc("noop"), is(nullValue()));
}
@Override
protected Classification mutateInstanceForVersion(Classification instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -5,9 +5,10 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.util.Map;
@ -18,7 +19,7 @@ import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
public class OutlierDetectionTests extends AbstractBWCSerializationTestCase<OutlierDetection> {
@Override
protected OutlierDetection doParseInstance(XContentParser parser) throws IOException {
@ -44,6 +45,17 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
.build();
}
public static OutlierDetection mutateForVersion(OutlierDetection instance, Version version) {
if (version.before(Version.V_7_5_0)) {
return new OutlierDetection.Builder(instance)
.setComputeFeatureInfluence(true)
.setOutlierFraction(0.05)
.setStandardizationEnabled(true)
.build();
}
return instance;
}
@Override
protected Writeable.Reader<OutlierDetection> instanceReader() {
return OutlierDetection::new;
@ -101,4 +113,9 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
assertThat(outlierDetection.persistsState(), is(false));
expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
}
@Override
protected OutlierDetection mutateInstanceForVersion(OutlierDetection instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.util.Collections;
@ -28,7 +28,7 @@ import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
public class RegressionTests extends AbstractBWCSerializationTestCase<Regression> {
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
@ -42,7 +42,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
return createRandom();
}
private static Regression createRandom() {
public static Regression createRandom() {
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
@ -51,6 +51,39 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
}
public static Regression mutateForVersion(Regression instance, Version version) {
return new Regression(instance.getDependentVariable(),
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
instance.getPredictionFieldName(),
instance.getTrainingPercent(),
instance.getRandomizeSeed());
}
@Override
protected void assertOnBWCObject(Regression bwcSerializedObject, Regression testInstance, Version version) {
if (version.onOrAfter(Version.V_7_6_0)) {
super.assertOnBWCObject(bwcSerializedObject, testInstance, version);
return;
}
Regression newBwc = new Regression(bwcSerializedObject.getDependentVariable(),
bwcSerializedObject.getBoostedTreeParams(),
bwcSerializedObject.getPredictionFieldName(),
bwcSerializedObject.getTrainingPercent(),
42L);
Regression newInstance = new Regression(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
testInstance.getTrainingPercent(),
42L);
super.assertOnBWCObject(newBwc, newInstance, version);
}
@Override
protected Regression mutateInstanceForVersion(Regression instance, Version version) {
return mutateForVersion(instance, version);
}
@Override
protected Writeable.Reader<Regression> instanceReader() {
return Regression::new;

View File

@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
@ -87,7 +88,9 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
task = mock(DataFrameAnalyticsTask.class);
when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class));
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
false,
OutlierDetectionTests.createRandom()).build();
dataExtractor = mock(DataFrameDataExtractor.class);
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);