diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java index 42ab45d101a..ac78f3ebfe2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java @@ -157,6 +157,11 @@ public class DataFrameAnalyticsSource implements Writeable, ToXContentObject { return queryProvider.getParsingException(); } + // visible for testing + QueryProvider getQueryProvider() { + return queryProvider; + } + /** * Calls the parser and returns any gathered deprecations * diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index e0890c21377..04ced8a2f4c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -190,6 +190,15 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { private Builder() {} + Builder(BoostedTreeParams params) { + this.lambda = params.lambda; + this.gamma = params.gamma; + this.eta = params.eta; + this.maximumNumberTrees = params.maximumNumberTrees; + this.featureBagFraction = params.featureBagFraction; + this.numTopFeatureImportanceValues = params.numTopFeatureImportanceValues; + } + public Builder setLambda(Double lambda) { this.lambda = lambda; return this; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index d834547c020..ece6f6a278b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -152,8 +152,7 @@ public class Classification implements DataFrameAnalysis { return trainingPercent; } - @Nullable - public Long getRandomizeSeed() { + public long getRandomizeSeed() { return randomizeSeed; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 264cd06f9f6..654d5ba4d1a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -271,6 +271,17 @@ public class OutlierDetection implements DataFrameAnalysis { private double outlierFraction = 0.05; private boolean standardizationEnabled = true; + public Builder() {} + + public Builder(OutlierDetection other) { + this.nNeighbors = other.nNeighbors; + this.method = other.method; + this.featureInfluenceThreshold = other.featureInfluenceThreshold; + this.computeFeatureInfluence = other.computeFeatureInfluence; + this.outlierFraction = other.outlierFraction; + this.standardizationEnabled = other.standardizationEnabled; + } + public Builder setNNeighbors(Integer nNeighbors) { this.nNeighbors = nNeighbors; return this; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index a4e4c423782..86f8039090c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -116,8 +116,7 @@ public class Regression implements DataFrameAnalysis { return trainingPercent; } - @Nullable - public Long getRandomizeSeed() { + public long getRandomizeSeed() { return randomizeSeed; } @@ -222,7 +221,7 @@ public class Regression implements DataFrameAnalysis { && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) && trainingPercent == that.trainingPercent - && randomizeSeed == randomizeSeed; + && randomizeSeed == that.randomizeSeed; } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java new file mode 100644 index 00000000000..f34f8fc008f --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.elasticsearch.Version.getDeclaredVersions; + +public abstract class AbstractBWCSerializationTestCase extends AbstractSerializingTestCase { + + private static final List ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class)); + + public static List getAllBWCVersions(Version version) { + return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList()); + } + + private static final List DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT); + + /** + * Returns the expected instance if serialized from the given version. + */ + protected abstract T mutateInstanceForVersion(T instance, Version version); + + /** + * The bwc versions to test serialization against + */ + protected List bwcVersions() { + return DEFAULT_BWC_VERSIONS; + } + + /** + * Test serialization and deserialization of the test instance across versions + */ + public final void testBwcSerialization() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) { + T testInstance = createTestInstance(); + for (Version bwcVersion : bwcVersions()) { + assertBwcSerialization(testInstance, bwcVersion); + } + } + } + + /** + * Assert that instances copied at a particular version are equal. The version is useful + * for sanity checking the backwards compatibility of the wire. It isn't a substitute for + * real backwards compatibility tests but it is *so* much faster. + */ + protected final void assertBwcSerialization(T testInstance, Version version) throws IOException { + T deserializedInstance = copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), version); + assertOnBWCObject(deserializedInstance, mutateInstanceForVersion(testInstance, version), version); + } + + /** + * @param bwcSerializedObject The object deserialized from the previous version + * @param testInstance The original test instance + * @param version The version which serialized + */ + protected void assertOnBWCObject(T bwcSerializedObject, T testInstance, Version version) { + assertNotSame(version.toString(), bwcSerializedObject, testInstance); + assertEquals(version.toString(), bwcSerializedObject, testInstance); + assertEquals(version.toString(), bwcSerializedObject.hashCode(), testInstance.hashCode()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 46df2a9d31e..8992530c1db 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -31,10 +31,15 @@ import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.ClassificationTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.junit.Before; @@ -56,7 +61,7 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; -public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { +public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestCase { @Override protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException { @@ -88,6 +93,83 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase instanceReader() { return DataFrameAnalyticsConfig::new; @@ -98,19 +180,23 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { +public class DataFrameAnalyticsDestTests extends AbstractBWCSerializationTestCase { @Override protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException { @@ -29,8 +30,17 @@ public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase instanceReader() { return DataFrameAnalyticsDest::new; } + + @Override + protected DataFrameAnalyticsDest mutateInstanceForVersion(DataFrameAnalyticsDest instance, Version version) { + return mutateForVersion(instance, version); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java index da62dd8cf4e..d227d4b2059 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; @@ -13,7 +14,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import java.io.IOException; @@ -25,7 +26,7 @@ import java.util.List; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase { +public class DataFrameAnalyticsSourceTests extends AbstractBWCSerializationTestCase { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { @@ -69,6 +70,13 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase instanceReader() { return DataFrameAnalyticsSource::new; @@ -132,4 +140,9 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase { +public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase { @Override protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOException { @@ -44,6 +45,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase instanceReader() { return BoostedTreeParams::new; @@ -111,4 +120,9 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase { +public class ClassificationTests extends AbstractBWCSerializationTestCase { private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @@ -62,6 +62,37 @@ public class ClassificationTests extends AbstractSerializingTestCase instanceReader() { return Classification::new; @@ -270,4 +301,9 @@ public class ClassificationTests extends AbstractSerializingTestCase { +public class OutlierDetectionTests extends AbstractBWCSerializationTestCase { @Override protected OutlierDetection doParseInstance(XContentParser parser) throws IOException { @@ -44,6 +45,17 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase instanceReader() { return OutlierDetection::new; @@ -101,4 +113,9 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase outlierDetection.getStateDocId("foo")); } + + @Override + protected OutlierDetection mutateInstanceForVersion(OutlierDetection instance, Version version) { + return mutateForVersion(instance, version); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index 46e866abb6e..d843acdea4d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import java.io.IOException; import java.util.Collections; @@ -28,7 +28,7 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; -public class RegressionTests extends AbstractSerializingTestCase { +public class RegressionTests extends AbstractBWCSerializationTestCase { private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @@ -42,7 +42,7 @@ public class RegressionTests extends AbstractSerializingTestCase { return createRandom(); } - private static Regression createRandom() { + public static Regression createRandom() { String dependentVariableName = randomAlphaOfLength(10); BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); @@ -51,6 +51,39 @@ public class RegressionTests extends AbstractSerializingTestCase { return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed); } + public static Regression mutateForVersion(Regression instance, Version version) { + return new Regression(instance.getDependentVariable(), + BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version), + instance.getPredictionFieldName(), + instance.getTrainingPercent(), + instance.getRandomizeSeed()); + } + + @Override + protected void assertOnBWCObject(Regression bwcSerializedObject, Regression testInstance, Version version) { + if (version.onOrAfter(Version.V_7_6_0)) { + super.assertOnBWCObject(bwcSerializedObject, testInstance, version); + return; + } + + Regression newBwc = new Regression(bwcSerializedObject.getDependentVariable(), + bwcSerializedObject.getBoostedTreeParams(), + bwcSerializedObject.getPredictionFieldName(), + bwcSerializedObject.getTrainingPercent(), + 42L); + Regression newInstance = new Regression(testInstance.getDependentVariable(), + testInstance.getBoostedTreeParams(), + testInstance.getPredictionFieldName(), + testInstance.getTrainingPercent(), + 42L); + super.assertOnBWCObject(newBwc, newInstance, version); + } + + @Override + protected Regression mutateInstanceForVersion(Regression instance, Version version) { + return mutateForVersion(instance, version); + } + @Override protected Writeable.Reader instanceReader() { return Regression::new; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 21c06b96c40..29f58ae18f0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; @@ -87,7 +88,9 @@ public class AnalyticsProcessManagerTests extends ESTestCase { task = mock(DataFrameAnalyticsTask.class); when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID); when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class)); - dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID); + dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID, + false, + OutlierDetectionTests.createRandom()).build(); dataExtractor = mock(DataFrameDataExtractor.class); when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS)); dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);