From 27497ff75f4cd94208a85af9ea5dece73b074a36 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 9 Aug 2019 19:31:13 +0300 Subject: [PATCH] [7.x][ML] Add regression analysis to DF analytics (#45292) (#45388) This commit adds a first draft of a regression analysis to data frame analytics. There is a high probability that the exact syntax will change. This commit adds the new analysis type and its parameters as well as appropriate validation. It also modifies the extractor and the fields detector to be able to handle categorical fields as regression analysis supports them. --- .../xpack/core/XPackClientPlugin.java | 2 + .../dataframe/analyses/DataFrameAnalysis.java | 14 + ...ataFrameAnalysisNamedXContentProvider.java | 6 + .../dataframe/analyses/OutlierDetection.java | 12 + .../ml/dataframe/analyses/Regression.java | 205 ++++++++++++++ .../persistence/ElasticsearchMappings.java | 26 ++ .../ml/job/results/ReservedFieldNames.java | 9 + .../dataframe/analyses/RegressionTests.java | 100 +++++++ .../ml/qa/ml-with-security/build.gradle | 9 + ...NativeDataFrameAnalyticsIntegTestCase.java | 10 + .../integration/RunDataFrameAnalyticsIT.java | 67 +++++ .../extractor/fields/ExtractedField.java | 51 ++-- .../extractor/fields/ExtractedFields.java | 19 +- .../fields/TimeBasedExtractedFields.java | 13 +- .../extractor/DataFrameDataExtractor.java | 15 +- .../extractor/ExtractedFieldsDetector.java | 61 +++-- .../process/AnalyticsProcessConfig.java | 9 +- .../process/AnalyticsProcessManager.java | 4 +- .../extractor/fields/ExtractedFieldTests.java | 72 +++-- .../fields/ExtractedFieldsTests.java | 18 +- .../fields/TimeBasedExtractedFieldsTests.java | 21 +- .../scroll/ScrollDataExtractorTests.java | 6 +- .../scroll/SearchHitToJsonProcessorTests.java | 25 +- .../DataFrameDataExtractorTests.java | 8 +- .../ExtractedFieldsDetectorTests.java | 100 +++++-- .../test/ml/data_frame_analytics_crud.yml | 256 +++++++++++++++++- ...nfigIndexMappingsFullClusterRestartIT.java | 52 ++-- 27 
files changed, 1026 insertions(+), 164 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 287e511d64c..b30b0f3cc0c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -147,6 +147,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; @@ -454,6 +455,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl MachineLearningFeatureSetUsage::new), // ML - Data frame analytics new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), // ML - Data frame evaluation new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new), diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index f21533d9176..47d0f96194a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -9,8 +9,22 @@ import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; import java.util.Map; +import java.util.Set; public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { + /** + * @return The analysis parameters as a map + */ Map getParams(); + + /** + * @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip) + */ + boolean supportsCategoricalFields(); + + /** + * @return The set of fields that analyzed documents must have for the analysis to operate + */ + Set getRequiredFields(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java index a48a23e4a83..e33a7748592 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java @@ -22,6 +22,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr boolean ignoreUnknownFields = (boolean) c; return OutlierDetection.fromXContent(p, ignoreUnknownFields); })); + namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, 
c) -> { + boolean ignoreUnknownFields = (boolean) c; + return Regression.fromXContent(p, ignoreUnknownFields); + })); return namedXContent; } @@ -31,6 +35,8 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), + Regression::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index e6891116ad6..35b3b5d3e95 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -16,10 +16,12 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; public class OutlierDetection implements DataFrameAnalysis { @@ -152,6 +154,16 @@ public class OutlierDetection implements DataFrameAnalysis { return params; } + @Override + public boolean supportsCategoricalFields() { + return false; + } + + @Override + public Set getRequiredFields() { + return Collections.emptySet(); + } + public enum Method { LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java new file mode 100644 index 
00000000000..a6b7c983a29 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -0,0 +1,205 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class Regression implements DataFrameAnalysis { + + public static final ParseField NAME = new ParseField("regression"); + + public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); + public static final ParseField LAMBDA = new ParseField("lambda"); + public static final ParseField GAMMA = new ParseField("gamma"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static 
ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient, + a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6])); + parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); + return parser; + } + + public static Regression fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return ignoreUnknownFields ? 
LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final String dependentVariable; + private final Double lambda; + private final Double gamma; + private final Double eta; + private final Integer maximumNumberTrees; + private final Double featureBagFraction; + private final String predictionFieldName; + + public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) { + this.dependentVariable = Objects.requireNonNull(dependentVariable); + + if (lambda != null && lambda < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); + } + this.lambda = lambda; + + if (gamma != null && gamma < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName()); + } + this.gamma = gamma; + + if (eta != null && (eta < 0.001 || eta > 1)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName()); + } + this.eta = eta; + + if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) { + throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName()); + } + this.maximumNumberTrees = maximumNumberTrees; + + if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); + } + this.featureBagFraction = featureBagFraction; + + this.predictionFieldName = predictionFieldName; + } + + public Regression(String dependentVariable) { + this(dependentVariable, null, null, null, null, null, null); + } + + public Regression(StreamInput in) throws IOException { + dependentVariable = in.readString(); + 
lambda = in.readOptionalDouble(); + gamma = in.readOptionalDouble(); + eta = in.readOptionalDouble(); + maximumNumberTrees = in.readOptionalVInt(); + featureBagFraction = in.readOptionalDouble(); + predictionFieldName = in.readOptionalString(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(dependentVariable); + out.writeOptionalDouble(lambda); + out.writeOptionalDouble(gamma); + out.writeOptionalDouble(eta); + out.writeOptionalVInt(maximumNumberTrees); + out.writeOptionalDouble(featureBagFraction); + out.writeOptionalString(predictionFieldName); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + if (lambda != null) { + builder.field(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + builder.field(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (predictionFieldName != null) { + builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + builder.endObject(); + return builder; + } + + @Override + public Map getParams() { + Map params = new HashMap<>(); + params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + if (lambda != null) { + params.put(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + params.put(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + params.put(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), 
maximumNumberTrees); + } + if (featureBagFraction != null) { + params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (predictionFieldName != null) { + params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + return params; + } + + @Override + public boolean supportsCategoricalFields() { + return true; + } + + @Override + public Set getRequiredFields() { + return Collections.singleton(dependentVariable); + } + + @Override + public int hashCode() { + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Regression that = (Regression) o; + return Objects.equals(dependentVariable, that.dependentVariable) + && Objects.equals(lambda, that.lambda) + && Objects.equals(gamma, that.gamma) + && Objects.equals(eta, that.eta) + && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(predictionFieldName, that.predictionFieldName); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index baf655a280d..11674bf26f4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import 
org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -443,6 +444,31 @@ public class ElasticsearchMappings { .endObject() .endObject() .endObject() + .startObject(Regression.NAME.getPreferredName()) + .startObject(PROPERTIES) + .startObject(Regression.DEPENDENT_VARIABLE.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(Regression.LAMBDA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.GAMMA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.ETA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() + .startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .endObject() + .endObject() .endObject() .endObject() // re-used: CREATE_TIME diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 76860e28481..92583693af2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; 
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -299,6 +300,14 @@ public final class ReservedFieldNames { OutlierDetection.N_NEIGHBORS.getPreferredName(), OutlierDetection.METHOD.getPreferredName(), OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), + Regression.NAME.getPreferredName(), + Regression.DEPENDENT_VARIABLE.getPreferredName(), + Regression.LAMBDA.getPreferredName(), + Regression.GAMMA.getPreferredName(), + Regression.ETA.getPreferredName(), + Regression.MAXIMUM_NUMBER_TREES.getPreferredName(), + Regression.FEATURE_BAG_FRACTION.getPreferredName(), + Regression.PREDICTION_FIELD_NAME.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java new file mode 100644 index 00000000000..e6a3dbbe8c2 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class RegressionTests extends AbstractSerializingTestCase { + + @Override + protected Regression doParseInstance(XContentParser parser) throws IOException { + return Regression.fromXContent(parser, false); + } + + @Override + protected Regression createTestInstance() { + return createRandom(); + } + + public static Regression createRandom() { + Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); + Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); + Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true); + Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000); + Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false); + String predictionFieldName = randomBoolean() ? 
null : randomAlphaOfLength(10); + return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + predictionFieldName); + } + + @Override + protected Writeable.Reader instanceReader() { + return Regression::new; + } + + public void testRegression_GivenNegativeLambda() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); + } + + public void testRegression_GivenNegativeGamma() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); + } + + public void testRegression_GivenEtaIsZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); + } + + public void testRegression_GivenEtaIsGreaterThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); + } + + public void testRegression_GivenMaximumNumberTreesIsZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); + } + + public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result")); + + 
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); + } + + public void testRegression_GivenFeatureBagFractionIsLessThanZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result")); + + assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); + } + + public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result")); + + assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 2fa1d8d4098..d3e7f73862d 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -69,6 +69,15 @@ integTest.runner { 'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/data_frame_analytics_crud/Test max model memory limit', + 'ml/data_frame_analytics_crud/Test put regression given dependent_variable is not defined', + 'ml/data_frame_analytics_crud/Test put regression given negative lambda', + 'ml/data_frame_analytics_crud/Test put regression given negative gamma', + 'ml/data_frame_analytics_crud/Test put regression given eta less than 1e-3', + 'ml/data_frame_analytics_crud/Test put regression given eta greater than one', + 'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is zero', + 'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k', + 'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative', + 
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one', 'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given missing evaluation', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 56ea04793c3..520f7a30ece 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import java.io.IOException; import java.util.ArrayList; @@ -118,4 +119,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest assertThat(stats.get(0).getId(), equalTo(id)); assertThat(stats.get(0).getState(), equalTo(state)); } + + protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex, + @Nullable String resultsField, String dependentVariable) { + DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); + configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); + configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField)); + 
configBuilder.setAnalysis(new Regression(dependentVariable)); + return configBuilder.build(); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 3e4fd4f7003..f1b9f8edf14 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -21,9 +21,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.junit.After; +import java.util.Arrays; +import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -362,4 +365,68 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); } + + public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { + String sourceIndex = "test-regression-with-numeric-feature-and-few-docs"; + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + List featureValues = Arrays.asList(1.0, 2.0, 3.0); + List dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0); + + for (int i = 0; i < 350; i++) { + Double field = featureValues.get(i % 3); + Double value = 
dependentVariableValues.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex); + if (i < 300) { + indexRequest.source("feature", field, "variable", value); + } else { + indexRequest.source("feature", field); + } + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_regression_with_numeric_feature_and_few_docs"; + DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex}, + sourceIndex + "-results", null, "variable"); + registerAnalytics(config); + putAnalytics(config); + + assertState(id, DataFrameAnalyticsState.STOPPED); + + startAnalytics(id); + waitUntilAnalyticsIsStopped(id); + + int resultsWithPrediction = 0; + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); + for (SearchHit hit : sourceData.getHits()) { + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); + assertThat(destDocGetResponse.isExists(), is(true)); + Map sourceDoc = hit.getSourceAsMap(); + Map destDoc = destDocGetResponse.getSource(); + for (String field : sourceDoc.keySet()) { + assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); + } + assertThat(destDoc.containsKey("ml"), is(true)); + + @SuppressWarnings("unchecked") + Map resultsObject = (Map) destDoc.get("ml"); + + if (resultsObject.containsKey("variable_prediction")) { + resultsWithPrediction++; + double featureValue = (double) destDoc.get("feature"); + double predictionValue = (double) resultsObject.get("variable_prediction"); + // it seems for this case values can be as far off as 2.0 + assertThat(predictionValue, closeTo(10 * featureValue, 2.0)); + } + } + assertThat(resultsWithPrediction, greaterThan(0)); + } } diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java index 7dafbb5f4dc..8c741e3c535 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java @@ -16,10 +16,12 @@ import org.elasticsearch.search.SearchHit; import java.io.IOException; import java.text.ParseException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; /** * Represents a field to be extracted by the datafeed. @@ -37,11 +39,14 @@ public abstract class ExtractedField { /** The name of the field we extract */ protected final String name; + private final Set types; + private final ExtractionMethod extractionMethod; - protected ExtractedField(String alias, String name, ExtractionMethod extractionMethod) { + protected ExtractedField(String alias, String name, Set types, ExtractionMethod extractionMethod) { this.alias = Objects.requireNonNull(alias); this.name = Objects.requireNonNull(name); + this.types = Objects.requireNonNull(types); this.extractionMethod = Objects.requireNonNull(extractionMethod); } @@ -53,6 +58,10 @@ public abstract class ExtractedField { return name; } + public Set getTypes() { + return types; + } + public ExtractionMethod getExtractionMethod() { return extractionMethod; } @@ -65,32 +74,32 @@ public abstract class ExtractedField { return null; } - public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) { + public static ExtractedField newTimeField(String name, Set types, ExtractionMethod extractionMethod) { if (extractionMethod == ExtractionMethod.SOURCE) { throw new IllegalArgumentException("time field cannot 
be extracted from source"); } - return new TimeField(name, extractionMethod); + return new TimeField(name, types, extractionMethod); } public static ExtractedField newGeoShapeField(String alias, String name) { - return new GeoShapeField(alias, name); + return new GeoShapeField(alias, name, Collections.singleton("geo_shape")); } public static ExtractedField newGeoPointField(String alias, String name) { - return new GeoPointField(alias, name); + return new GeoPointField(alias, name, Collections.singleton("geo_point")); } - public static ExtractedField newField(String name, ExtractionMethod extractionMethod) { - return newField(name, name, extractionMethod); + public static ExtractedField newField(String name, Set types, ExtractionMethod extractionMethod) { + return newField(name, name, types, extractionMethod); } - public static ExtractedField newField(String alias, String name, ExtractionMethod extractionMethod) { + public static ExtractedField newField(String alias, String name, Set types, ExtractionMethod extractionMethod) { switch (extractionMethod) { case DOC_VALUE: case SCRIPT_FIELD: - return new FromFields(alias, name, extractionMethod); + return new FromFields(alias, name, types, extractionMethod); case SOURCE: - return new FromSource(alias, name); + return new FromSource(alias, name, types); default: throw new IllegalArgumentException("Invalid extraction method [" + extractionMethod + "]"); } @@ -98,7 +107,7 @@ public abstract class ExtractedField { public ExtractedField newFromSource() { if (supportsFromSource()) { - return new FromSource(alias, name); + return new FromSource(alias, name, types); } throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via [" + extractionMethod + "] and cannot be extracted from source"); @@ -106,8 +115,8 @@ public abstract class ExtractedField { private static class FromFields extends ExtractedField { - FromFields(String alias, String name, ExtractionMethod extractionMethod) 
{ - super(alias, name, extractionMethod); + FromFields(String alias, String name, Set types, ExtractionMethod extractionMethod) { + super(alias, name, types, extractionMethod); } @Override @@ -129,8 +138,8 @@ public abstract class ExtractedField { private static class GeoShapeField extends FromSource { private static final WellKnownText wkt = new WellKnownText(true, new StandardValidator(true)); - GeoShapeField(String alias, String name) { - super(alias, name); + GeoShapeField(String alias, String name, Set types) { + super(alias, name, types); } @Override @@ -186,8 +195,8 @@ public abstract class ExtractedField { private static class GeoPointField extends FromFields { - GeoPointField(String alias, String name) { - super(alias, name, ExtractionMethod.DOC_VALUE); + GeoPointField(String alias, String name, Set types) { + super(alias, name, types, ExtractionMethod.DOC_VALUE); } @Override @@ -222,8 +231,8 @@ public abstract class ExtractedField { private static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; - TimeField(String name, ExtractionMethod extractionMethod) { - super(name, name, extractionMethod); + TimeField(String name, Set types, ExtractionMethod extractionMethod) { + super(name, name, types, extractionMethod); } @Override @@ -255,8 +264,8 @@ public abstract class ExtractedField { private String[] namePath; - FromSource(String alias, String name) { - super(alias, name, ExtractionMethod.SOURCE); + FromSource(String alias, String name, Set types) { + super(alias, name, types, ExtractionMethod.SOURCE); namePath = name.split("\\."); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java index 9495c5a2b40..6c90d2c7db2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java @@ -47,15 +47,6 @@ public class ExtractedFields { return docValueFields; } - /** - * Returns a new instance which only contains fields matching the given extraction method - * @param method the extraction method to filter fields on - * @return a new instance which only contains fields matching the given extraction method - */ - public ExtractedFields filterFields(ExtractedField.ExtractionMethod method) { - return new ExtractedFields(filterFields(method, allFields)); - } - private static List filterFields(ExtractedField.ExtractionMethod method, List fields) { return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList()); } @@ -79,12 +70,13 @@ public class ExtractedFields { protected ExtractedField detect(String field) { String internalField = field; ExtractedField.ExtractionMethod method = ExtractedField.ExtractionMethod.SOURCE; + Set types = getTypes(field); if (scriptFields.contains(field)) { method = ExtractedField.ExtractionMethod.SCRIPT_FIELD; } else if (isAggregatable(field)) { method = ExtractedField.ExtractionMethod.DOC_VALUE; if (isFieldOfType(field, "date")) { - return ExtractedField.newTimeField(field, method); + return ExtractedField.newTimeField(field, types, method); } } else if (isFieldOfType(field, TEXT)) { String parentField = MlStrings.getParentField(field); @@ -107,7 +99,12 @@ public class ExtractedFields { return ExtractedField.newGeoShapeField(field, internalField); } - return ExtractedField.newField(field, internalField, method); + return ExtractedField.newField(field, internalField, types, method); + } + + private Set getTypes(String field) { + Map fieldCaps = fieldsCapabilities.getField(field); + return fieldCaps == null ? 
Collections.emptySet() : fieldCaps.keySet(); } protected boolean isAggregatable(String field) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java index cf87671bf33..1067ef63007 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Job; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; @@ -55,12 +56,20 @@ public class TimeBasedExtractedFields extends ExtractedFields { if (scriptFields.contains(timeField) == false && extractionMethodDetector.isAggregatable(timeField) == false) { throw new IllegalArgumentException("cannot retrieve time field [" + timeField + "] because it is not aggregatable"); } - ExtractedField timeExtractedField = ExtractedField.newTimeField(timeField, scriptFields.contains(timeField) ? 
- ExtractedField.ExtractionMethod.SCRIPT_FIELD : ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeExtractedField = extractedTimeField(timeField, scriptFields, fieldsCapabilities); List remainingFields = job.allInputFields().stream().filter(f -> !f.equals(timeField)).collect(Collectors.toList()); List allExtractedFields = new ArrayList<>(remainingFields.size() + 1); allExtractedFields.add(timeExtractedField); remainingFields.stream().forEach(field -> allExtractedFields.add(extractionMethodDetector.detect(field))); return new TimeBasedExtractedFields(timeExtractedField, allExtractedFields); } + + private static ExtractedField extractedTimeField(String timeField, Set scriptFields, + FieldCapabilitiesResponse fieldCapabilities) { + if (scriptFields.contains(timeField)) { + return ExtractedField.newTimeField(timeField, Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD); + } + return ExtractedField.newTimeField(timeField, fieldCapabilities.getField(timeField).keySet(), + ExtractedField.ExtractionMethod.DOC_VALUE); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index fa18f3bb25b..d9f1aa994d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -29,11 +29,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -179,7 
+181,7 @@ public class DataFrameDataExtractor { for (int i = 0; i < extractedValues.length; ++i) { ExtractedField field = context.extractedFields.getAllFields().get(i); Object[] values = field.value(hit); - if (values.length == 1 && values[0] instanceof Number) { + if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { extractedValues[i] = Objects.toString(values[0]); } else { extractedValues = null; @@ -233,6 +235,17 @@ public class DataFrameDataExtractor { return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size()); } + public Set getCategoricalFields() { + Set categoricalFields = new HashSet<>(); + for (ExtractedField extractedField : context.extractedFields.getAllFields()) { + String fieldName = extractedField.getName(); + if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) { + categoricalFields.add(fieldName); + } + } + return categoricalFields; + } + public static class DataSummary { public final long rows; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index d58eaebe353..3ff8c8a4923 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.ml.dataframe.extractor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; @@ -20,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import 
org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex; import java.util.ArrayList; import java.util.Arrays; @@ -35,24 +38,24 @@ import java.util.stream.Stream; public class ExtractedFieldsDetector { + private static final Logger LOGGER = LogManager.getLogger(ExtractedFieldsDetector.class); + /** * Fields to ignore. These are mostly internal meta fields. */ private static final List IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", - "_source", "_type", "_uid", "_version", "_feature", "_ignored"); + "_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY); - /** - * The types supported by data frames - */ - private static final Set COMPATIBLE_FIELD_TYPES; + public static final Set CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip"))); + + private static final Set NUMERICAL_TYPES; static { - Set compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values()) + Set numericalTypes = Stream.of(NumberFieldMapper.NumberType.values()) .map(NumberFieldMapper.NumberType::typeName) .collect(Collectors.toSet()); - compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module - - COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes); + numericalTypes.add("scaled_float"); + NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes); } private final String[] index; @@ -79,16 +82,18 @@ public class ExtractedFieldsDetector { // Ignore fields under the results object fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + ".")); + includeAndExcludeFields(fields); removeFieldsWithIncompatibleTypes(fields); - includeAndExcludeFields(fields, index); + 
checkRequiredFieldsArePresent(fields); + + if (fields.isEmpty()) { + throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index)); + } + List sortedFields = new ArrayList<>(fields); // We sort the fields to ensure the checksum for each document is deterministic Collections.sort(sortedFields); - ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) - .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); - if (extractedFields.getAllFields().isEmpty()) { - throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index)); - } + ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse); if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { extractedFields = fetchFromSourceIfSupported(extractedFields); if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { @@ -120,13 +125,25 @@ public class ExtractedFieldsDetector { while (fieldsIterator.hasNext()) { String field = fieldsIterator.next(); Map fieldCaps = fieldCapabilitiesResponse.getField(field); - if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) { + if (fieldCaps == null) { + LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field); fieldsIterator.remove(); + } else { + Set fieldTypes = fieldCaps.keySet(); + if (NUMERICAL_TYPES.containsAll(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field); + } else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field); + } else { + LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}", + config.getId(), 
field, fieldTypes); + fieldsIterator.remove(); + } } } } - private void includeAndExcludeFields(Set fields, String[] index) { + private void includeAndExcludeFields(Set fields) { FetchSourceContext analyzedFields = config.getAnalyzedFields(); if (analyzedFields == null) { return; @@ -159,6 +176,16 @@ public class ExtractedFieldsDetector { } } + private void checkRequiredFieldsArePresent(Set fields) { + List missingFields = config.getAnalysis().getRequiredFields() + .stream() + .filter(f -> fields.contains(f) == false) + .collect(Collectors.toList()); + if (missingFields.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields); + } + } + private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) { List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); for (ExtractedField field : extractedFields.getDocValueFields()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 226498376bb..70a2e213fb6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import java.io.IOException; import java.util.Objects; +import java.util.Set; public class AnalyticsProcessConfig implements ToXContentObject { @@ -21,21 +22,24 @@ public class AnalyticsProcessConfig implements ToXContentObject { private static final String THREADS = "threads"; private static final String ANALYSIS = "analysis"; private static final String RESULTS_FIELD = "results_field"; + private static final String CATEGORICAL_FIELDS = "categorical_fields"; private final long rows; 
private final int cols; private final ByteSizeValue memoryLimit; private final int threads; - private final DataFrameAnalysis analysis; private final String resultsField; + private final Set categoricalFields; + private final DataFrameAnalysis analysis; public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField, - DataFrameAnalysis analysis) { + Set categoricalFields, DataFrameAnalysis analysis) { this.rows = rows; this.cols = cols; this.memoryLimit = Objects.requireNonNull(memoryLimit); this.threads = threads; this.resultsField = Objects.requireNonNull(resultsField); + this.categoricalFields = Objects.requireNonNull(categoricalFields); this.analysis = Objects.requireNonNull(analysis); } @@ -51,6 +55,7 @@ public class AnalyticsProcessConfig implements ToXContentObject { builder.field(MEMORY_LIMIT, memoryLimit.getBytes()); builder.field(THREADS, threads); builder.field(RESULTS_FIELD, resultsField); + builder.field(CATEGORICAL_FIELDS, categoricalFields); builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis)); builder.endObject(); return builder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index cb000a15496..f04ba577be4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; @@ -283,8 +284,9 @@ public class AnalyticsProcessManager { private AnalyticsProcessConfig 
createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + Set categoricalFields = dataExtractor.getCategoricalFields(); AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis()); + config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis()); return processConfig; } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java index 6969c97be0a..87f86a33f99 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import java.util.Arrays; +import java.util.Collections; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.startsWith; @@ -19,46 +20,51 @@ public class ExtractedFieldTests extends ESTestCase { public void testValueGivenDocValue() { SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build(); - ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField single = ExtractedField.newField("single", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(single.value(hit), equalTo(new String[] { "bar" })); - ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); + 
ExtractedField array = ExtractedField.newField("array", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField missing = ExtractedField.newField("missing",Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(missing.value(hit), equalTo(new Object[0])); } public void testValueGivenScriptField() { SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build(); - ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField single = ExtractedField.newField("single",Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); assertThat(single.value(hit), equalTo(new String[] { "bar" })); - ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField array = ExtractedField.newField("array", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD); assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField missing = ExtractedField.newField("missing", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD); assertThat(missing.value(hit), equalTo(new Object[0])); } public void testValueGivenSource() { SearchHit hit = new SearchHitBuilder(42).setSource("{\"single\":\"bar\",\"array\":[\"a\",\"b\"]}").build(); - ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField single = ExtractedField.newField("single", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(single.value(hit), equalTo(new String[] 
{ "bar" })); - ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField array = ExtractedField.newField("array", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(missing.value(hit), equalTo(new Object[0])); } public void testValueGivenNestedSource() { SearchHit hit = new SearchHitBuilder(42).setSource("{\"level_1\":{\"level_2\":{\"foo\":\"bar\"}}}").build(); - ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); assertThat(nested.value(hit), equalTo(new String[] { "bar" })); } @@ -91,49 +97,54 @@ public class ExtractedFieldTests extends ESTestCase { } public void testValueGivenSourceAndHitWithNoSource() { - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(missing.value(new SearchHitBuilder(3).build()), equalTo(new Object[0])); } public void testValueGivenMismatchingMethod() { SearchHit hit = new SearchHitBuilder(42).addField("a", 1).setSource("{\"b\":2}").build(); - ExtractedField invalidA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField invalidA = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(invalidA.value(hit), equalTo(new Object[0])); - ExtractedField validA 
= ExtractedField.newField("a", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField validA = ExtractedField.newField("a", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(validA.value(hit), equalTo(new Integer[] { 1 })); - ExtractedField invalidB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField invalidB = ExtractedField.newField("b", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(invalidB.value(hit), equalTo(new Object[0])); - ExtractedField validB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField validB = ExtractedField.newField("b", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(validB.value(hit), equalTo(new Integer[] { 2 })); } public void testValueGivenEmptyHit() { SearchHit hit = new SearchHitBuilder(42).build(); - ExtractedField docValue = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField docValue = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(docValue.value(hit), equalTo(new Object[0])); - ExtractedField sourceField = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField sourceField = ExtractedField.newField("b", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(sourceField.value(hit), equalTo(new Object[0])); } public void testNewTimeFieldGivenSource() { - expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.SOURCE)); + expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.SOURCE)); } public void testValueGivenStringTimeField() { final long millis = randomLong(); final SearchHit hit = new 
SearchHitBuilder(randomInt()).addField("time", Long.toString(millis)).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(timeField.value(hit), equalTo(new Object[] { millis })); } public void testValueGivenLongTimeField() { final long millis = randomLong(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(timeField.value(hit), equalTo(new Object[] { millis })); } @@ -141,13 +152,15 @@ public class ExtractedFieldTests extends ESTestCase { // Prior to 6.x, timestamps were simply `long` milliseconds-past-the-epoch values final long millis = randomLong(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(timeField.value(hit), equalTo(new Object[] { millis })); } public void testValueGivenUnknownFormatTimeField() { final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", new Object()).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(expectThrows(IllegalStateException.class, () -> 
timeField.value(hit)).getMessage(), startsWith("Unexpected value for a time field")); } @@ -155,14 +168,15 @@ public class ExtractedFieldTests extends ESTestCase { public void testAliasVersusName() { SearchHit hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build(); - ExtractedField field = ExtractedField.newField("a", "a", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField field = ExtractedField.newField("a", "a", Collections.singleton("int"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(field.getAlias(), equalTo("a")); assertThat(field.getName(), equalTo("a")); assertThat(field.value(hit), equalTo(new Integer[] { 1 })); hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build(); - field = ExtractedField.newField("a", "b", ExtractedField.ExtractionMethod.DOC_VALUE); + field = ExtractedField.newField("a", "b", Collections.singleton("int"), ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(field.getAlias(), equalTo("a")); assertThat(field.getName(), equalTo("b")); assertThat(field.value(hit), equalTo(new Integer[] { 2 })); @@ -170,11 +184,11 @@ public class ExtractedFieldTests extends ESTestCase { public void testGetDocValueFormat() { for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) { - assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(null)); + assertThat(ExtractedField.newField("f", Collections.emptySet(), method).getDocValueFormat(), equalTo(null)); } - assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), - equalTo("epoch_millis")); - assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), - equalTo("epoch_millis")); + assertThat(ExtractedField.newTimeField("doc_value_time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), equalTo("epoch_millis")); + 
assertThat(ExtractedField.newTimeField("source_time", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), equalTo("epoch_millis")); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java index db25f820dbb..8dd81b47eb8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java @@ -27,12 +27,18 @@ import static org.mockito.Mockito.when; public class ExtractedFieldsTests extends ESTestCase { public void testAllTypesOfFields() { - ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE); - ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("ip"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(), + 
ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); + ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java index 6e7a3740e0a..652eb068783 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java @@ -29,7 +29,8 @@ import static org.mockito.Mockito.when; public class TimeBasedExtractedFieldsTests extends ESTestCase { - private ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + private ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); public void testInvalidConstruction() { expectThrows(IllegalArgumentException.class, () -> new TimeBasedExtractedFields(timeField, Collections.emptyList())); @@ -46,12 +47,18 @@ public class TimeBasedExtractedFieldsTests extends ESTestCase { } public void testAllTypesOfFields() { - ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField scriptField1 = ExtractedField.newField("scripted1", 
ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE); - ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("float"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); + ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java index c383cf20b18..bdbe81a66a6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java @@ -135,9 +135,11 @@ public class ScrollDataExtractorTests extends ESTestCase 
{ capturedSearchRequests = new ArrayList<>(); capturedContinueScrollIds = new ArrayList<>(); jobId = "test-job"; - ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); extractedFields = new TimeBasedExtractedFields(timeField, - Arrays.asList(timeField, ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE))); + Arrays.asList(timeField, ExtractedField.newField("field_1", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE))); indices = Arrays.asList("index-1", "index-2"); query = QueryBuilders.matchAllQuery(); scriptFields = Collections.emptyList(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java index 41a74814461..d2befb407ae 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java @@ -16,16 +16,21 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Collections; import static org.hamcrest.Matchers.equalTo; public class SearchHitToJsonProcessorTests extends ESTestCase { public void testProcessGivenSingleHit() throws IOException { - ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField singleField = ExtractedField.newField("single", 
ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField)); @@ -41,10 +46,14 @@ public class SearchHitToJsonProcessorTests extends ESTestCase { } public void testProcessGivenMultipleHits() throws IOException { - ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"), + 
ExtractedField.ExtractionMethod.DOC_VALUE); TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index ffd53e5576f..b456de7b637 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -71,8 +71,8 @@ public class DataFrameDataExtractorTests extends ESTestCase { indices = Arrays.asList("index-1", "index-2"); query = QueryBuilders.matchAllQuery(); extractedFields = new ExtractedFields(Arrays.asList( - ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), - ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.DOC_VALUE))); + ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_2", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE))); scrollSize = 1000; headers = Collections.emptyMap(); @@ -288,8 +288,8 @@ public class DataFrameDataExtractorTests extends ESTestCase { public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException { extractedFields = new ExtractedFields(Arrays.asList( - ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), - ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.SOURCE))); + ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_2", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE))); TestExtractor 
dataExtractor = createExtractor(false); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 1345a1fe128..5f781538bec 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; @@ -38,11 +39,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { private static final String RESULTS_FIELD = "ml"; public void testDetect_GivenFloatField() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); @@ -52,12 +53,12 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { } public void testDetect_GivenNumericFieldWithMultipleTypes() { - 
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); @@ -67,36 +68,36 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { } public void testDetect_GivenNonNumericField() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_keyword", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetect_GivenFieldWithNumericAndNonNumericTypes() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericTypes() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("indecisive_field", "float", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); 
+ SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetect_GivenMultipleFields() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + public void testDetect_GivenOutlierDetectionAndMultipleFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float") .addAggregatableField("some_long", "long") .addAggregatableField("some_keyword", "keyword") .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); @@ -107,12 +108,46 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); } + public void testDetect_GivenRegressionAndMultipleFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .addAggregatableField("foo", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(4)); + assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), + 
contains("foo", "some_float", "some_keyword", "some_long")); + assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()), + contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); + } + + public void testDetect_GivenRegressionAndRequiredFieldMissing() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("required fields [foo] are missing")); + } + public void testDetect_GivenIgnoredField() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("_id", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); @@ -134,7 +169,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, 
fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -151,7 +186,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected")); @@ -166,7 +201,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -182,7 +217,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities); ExtractedFields 
extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -199,7 +234,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " + @@ -215,7 +250,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -232,7 +267,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 4, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -251,7 +286,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 3, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities); ExtractedFields 
extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -270,7 +305,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 2, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -280,11 +315,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.ExtractionMethod.SOURCE))); } - private static DataFrameAnalyticsConfig buildAnalyticsConfig() { - return buildAnalyticsConfig(null); + private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() { + return buildOutlierDetectionConfig(null); } - private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analyzedFields) { + private static DataFrameAnalyticsConfig buildOutlierDetectionConfig(FetchSourceContext analyzedFields) { return new DataFrameAnalyticsConfig.Builder("foo") .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) @@ -293,6 +328,19 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); } + private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) { + return buildRegressionConfig(dependentVariable, null); + } + + private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, FetchSourceContext analyzedFields) { + return new DataFrameAnalyticsConfig.Builder("foo") + .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) + .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) + .setAnalyzedFields(analyzedFields) + .setAnalysis(new 
Regression(dependentVariable)) + .build(); + } + private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 772c48e5474..253790878c5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -607,7 +607,11 @@ setup: "dest": { "index": "index-bar_dest" }, - "analysis": {"outlier_detection":{}} + "analysis": { + "regression":{ + "dependent_variable": "to_predict" + } + } } - match: { id: "bar" } @@ -768,7 +772,11 @@ setup: "dest": { "index": "index-bar_dest" }, - "analysis": {"outlier_detection":{}} + "analysis": { + "regression":{ + "dependent_variable": "to_predict" + } + } } - match: { id: "bar" } @@ -930,3 +938,247 @@ setup: xpack.ml.max_model_memory_limit: null - match: {transient: {}} +--- +"Test put regression given dependent_variable is not defined": + + - do: + catch: /parse_exception/ + ml.put_data_frame_analytics: + id: "regression-without-dependent-variable" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": {} + } + } + +--- +"Test put regression given negative lambda": + + - do: + catch: /\[lambda\] must be a non-negative double/ + ml.put_data_frame_analytics: + id: "regression-negative-lambda" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "lambda": -1.0 + } + } + } + +--- +"Test put regression given negative gamma": + + - do: + catch: /\[gamma\] must be a non-negative double/ + ml.put_data_frame_analytics: + id: "regression-negative-gamma" + body: > + { + "source": { + "index": "index-source" + 
}, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "gamma": -1.0 + } + } + } + +--- +"Test put regression given eta less than 1e-3": + + - do: + catch: /\[eta\] must be a double in \[0.001, 1\]/ + ml.put_data_frame_analytics: + id: "regression-eta-greater-less-than-valid" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "eta": 0.0009 + } + } + } + +--- +"Test put regression given eta greater than one": + + - do: + catch: /\[eta\] must be a double in \[0.001, 1\]/ + ml.put_data_frame_analytics: + id: "regression-eta-greater-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "eta": 1.00001 + } + } + } + +--- +"Test put regression given maximum_number_trees is zero": + + - do: + catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/ + ml.put_data_frame_analytics: + id: "regression-maximum-number-trees-is-zero" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "maximum_number_trees": 0 + } + } + } + +--- +"Test put regression given maximum_number_trees is greater than 2k": + + - do: + catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/ + ml.put_data_frame_analytics: + id: "regression-maximum-number-trees-greater-than-2k" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "maximum_number_trees": 2001 + } + } + } + +--- +"Test put regression given feature_bag_fraction is negative": + + - do: + catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/ + ml.put_data_frame_analytics: + id: 
"regression-feature-bag-fraction-is-negative" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "feature_bag_fraction": -0.0001 + } + } + } + +--- +"Test put regression given feature_bag_fraction is greater than one": + + - do: + catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/ + ml.put_data_frame_analytics: + id: "regression-feature-bag-fraction-is-greater-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "feature_bag_fraction": 1.0001 + } + } + } + +--- +"Test put regression given valid": + + - do: + ml.put_data_frame_analytics: + id: "valid-regression" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "lambda": 3.14, + "gamma": 0.42, + "eta": 0.5, + "maximum_number_trees": 400, + "feature_bag_fraction": 0.3 + } + } + } + - match: { id: "valid-regression" } + - match: { source.index: ["index-source"] } + - match: { dest.index: "index-dest" } + - match: { analysis: { + "regression":{ + "dependent_variable": "foo", + "lambda": 3.14, + "gamma": 0.42, + "eta": 0.5, + "maximum_number_trees": 400, + "feature_bag_fraction": 0.3 + } + }} + - is_true: create_time + - is_true: version diff --git a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java index 4bbef475f99..8b7f9d06bf5 100644 --- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java +++ 
b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java @@ -11,15 +11,22 @@ import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.WarningFailureException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.Detector; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.test.rest.XPackRestTestConstants; import org.elasticsearch.xpack.test.rest.XPackRestTestHelper; import org.junit.Before; @@ -28,12 +35,12 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClusterRestartTestCase 
{ @@ -41,14 +48,23 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust private static final String OLD_CLUSTER_JOB_ID = "ml-config-mappings-old-cluster-job"; private static final String NEW_CLUSTER_JOB_ID = "ml-config-mappings-new-cluster-job"; - private static final Map EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = - mapOf( - "properties", mapOf( - "outlier_detection", mapOf( - "properties", mapOf( - "method", mapOf("type", "keyword"), - "n_neighbors", mapOf("type", "integer"), - "feature_influence_threshold", mapOf("type", "double"))))); + private static final Map EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = getDataFrameAnalysisMappings(); + + @SuppressWarnings("unchecked") + private static Map getDataFrameAnalysisMappings() { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + ElasticsearchMappings.addDataFrameAnalyticsFields(builder); + builder.endObject(); + + Map asMap = builder.generator().contentType().xContent().createParser( + NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput()).map(); + return (Map) asMap.get(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName()); + } catch (IOException e) { + fail("Failed to initialize expected data frame analysis mappings"); + } + return null; + } @Override protected Settings restClientSettings() { @@ -71,8 +87,8 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust // trigger .ml-config index creation createAnomalyDetectorJob(OLD_CLUSTER_JOB_ID); if (getOldClusterVersion().onOrAfter(Version.V_7_3_0)) { - // .ml-config has correct mappings from the start - assertThat(mappingsForDataFrameAnalysis(), is(equalTo(EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS))); + // .ml-config has mappings for analytics as the feature was introduced in 7.3.0 + assertThat(mappingsForDataFrameAnalysis(), is(notNullValue())); } else { // .ml-config does not yet have correct mappings, it will need an 
update after cluster is upgraded assertThat(mappingsForDataFrameAnalysis(), is(nullValue())); @@ -125,18 +141,4 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust mappings = (Map) XContentMapValues.extractValue(mappings, "properties", "analysis"); return mappings; } - - private static Map mapOf(K k1, V v1) { - Map map = new HashMap<>(); - map.put(k1, v1); - return map; - } - - private static Map mapOf(K k1, V v1, K k2, V v2, K k3, V v3) { - Map map = new HashMap<>(); - map.put(k1, v1); - map.put(k2, v2); - map.put(k3, v3); - return map; - } }