diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 287e511d64c..b30b0f3cc0c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -147,6 +147,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; @@ -454,6 +455,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl MachineLearningFeatureSetUsage::new), // ML - Data frame analytics new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), // ML - Data frame evaluation new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index f21533d9176..47d0f96194a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -9,8 +9,22 @@ import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; import java.util.Map; +import java.util.Set; public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { + /** + * @return The analysis parameters as a map + */ Map getParams(); + + /** + * @return {@code true} if this analysis supports fields with categorical values (i.e. 
text, keyword, ip) + */ + boolean supportsCategoricalFields(); + + /** + * @return The set of fields that analyzed documents must have for the analysis to operate + */ + Set getRequiredFields(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java index a48a23e4a83..e33a7748592 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java @@ -22,6 +22,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr boolean ignoreUnknownFields = (boolean) c; return OutlierDetection.fromXContent(p, ignoreUnknownFields); })); + namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return Regression.fromXContent(p, ignoreUnknownFields); + })); return namedXContent; } @@ -31,6 +35,8 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), + Regression::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index e6891116ad6..35b3b5d3e95 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -16,10 +16,12 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; public class OutlierDetection implements DataFrameAnalysis { @@ -152,6 +154,16 @@ public class OutlierDetection implements DataFrameAnalysis { return params; } + @Override + public boolean supportsCategoricalFields() { + return false; + } + + @Override + public Set getRequiredFields() { + return Collections.emptySet(); + } + public enum Method { LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java new file mode 100644 index 00000000000..a6b7c983a29 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -0,0 +1,205 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class Regression implements DataFrameAnalysis { + + public static final ParseField NAME = new ParseField("regression"); + + public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); + public static final ParseField LAMBDA = new ParseField("lambda"); + public static final ParseField GAMMA = new ParseField("gamma"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient, + a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6])); + parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); + return parser; + } + + public static Regression fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return ignoreUnknownFields ? 
LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final String dependentVariable; + private final Double lambda; + private final Double gamma; + private final Double eta; + private final Integer maximumNumberTrees; + private final Double featureBagFraction; + private final String predictionFieldName; + + public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) { + this.dependentVariable = Objects.requireNonNull(dependentVariable); + + if (lambda != null && lambda < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); + } + this.lambda = lambda; + + if (gamma != null && gamma < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName()); + } + this.gamma = gamma; + + if (eta != null && (eta < 0.001 || eta > 1)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName()); + } + this.eta = eta; + + if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) { + throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName()); + } + this.maximumNumberTrees = maximumNumberTrees; + + if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); + } + this.featureBagFraction = featureBagFraction; + + this.predictionFieldName = predictionFieldName; + } + + public Regression(String dependentVariable) { + this(dependentVariable, null, null, null, null, null, null); + } + + public Regression(StreamInput in) throws IOException { + dependentVariable = in.readString(); + lambda = in.readOptionalDouble(); + gamma = in.readOptionalDouble(); + eta = in.readOptionalDouble(); + maximumNumberTrees = in.readOptionalVInt(); + featureBagFraction = in.readOptionalDouble(); + predictionFieldName = in.readOptionalString(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(dependentVariable); + out.writeOptionalDouble(lambda); + out.writeOptionalDouble(gamma); + out.writeOptionalDouble(eta); + out.writeOptionalVInt(maximumNumberTrees); + out.writeOptionalDouble(featureBagFraction); + out.writeOptionalString(predictionFieldName); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + if (lambda != null) { + builder.field(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + builder.field(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (predictionFieldName != null) { + builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + builder.endObject(); + return builder; + } + + @Override + 
public Map getParams() { + Map params = new HashMap<>(); + params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + if (lambda != null) { + params.put(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + params.put(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + params.put(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (predictionFieldName != null) { + params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + return params; + } + + @Override + public boolean supportsCategoricalFields() { + return true; + } + + @Override + public Set getRequiredFields() { + return Collections.singleton(dependentVariable); + } + + @Override + public int hashCode() { + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Regression that = (Regression) o; + return Objects.equals(dependentVariable, that.dependentVariable) + && Objects.equals(lambda, that.lambda) + && Objects.equals(gamma, that.gamma) + && Objects.equals(eta, that.eta) + && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(predictionFieldName, that.predictionFieldName); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index baf655a280d..11674bf26f4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -443,6 +444,31 @@ public class ElasticsearchMappings { .endObject() .endObject() .endObject() + .startObject(Regression.NAME.getPreferredName()) + .startObject(PROPERTIES) + .startObject(Regression.DEPENDENT_VARIABLE.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(Regression.LAMBDA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.GAMMA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.ETA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() + .startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) + 
.field(TYPE, KEYWORD) + .endObject() + .endObject() + .endObject() .endObject() .endObject() // re-used: CREATE_TIME diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 76860e28481..92583693af2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -299,6 +300,14 @@ public final class ReservedFieldNames { OutlierDetection.N_NEIGHBORS.getPreferredName(), OutlierDetection.METHOD.getPreferredName(), OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), + Regression.NAME.getPreferredName(), + Regression.DEPENDENT_VARIABLE.getPreferredName(), + Regression.LAMBDA.getPreferredName(), + Regression.GAMMA.getPreferredName(), + Regression.ETA.getPreferredName(), + Regression.MAXIMUM_NUMBER_TREES.getPreferredName(), + Regression.FEATURE_BAG_FRACTION.getPreferredName(), + Regression.PREDICTION_FIELD_NAME.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java new file mode 100644 index 00000000000..e6a3dbbe8c2 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class RegressionTests extends AbstractSerializingTestCase { + + @Override + protected Regression doParseInstance(XContentParser parser) throws IOException { + return Regression.fromXContent(parser, false); + } + + @Override + protected Regression createTestInstance() { + return createRandom(); + } + + public static Regression createRandom() { + Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); + Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); + Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true); + Integer maximumNumberTrees = randomBoolean() ? 
null : randomIntBetween(1, 2000); + Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false); + String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); + return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + predictionFieldName); + } + + @Override + protected Writeable.Reader instanceReader() { + return Regression::new; + } + + public void testRegression_GivenNegativeLambda() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); + } + + public void testRegression_GivenNegativeGamma() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); + } + + public void testRegression_GivenEtaIsZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); + } + + public void testRegression_GivenEtaIsGreaterThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); + } + + public void testRegression_GivenMaximumNumberTreesIsZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); + } + + public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result")); + + assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); + } + + public void testRegression_GivenFeatureBagFractionIsLessThanZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result")); + + assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); + } + + public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result")); + + assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 2fa1d8d4098..d3e7f73862d 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -69,6 +69,15 @@ integTest.runner { 'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/data_frame_analytics_crud/Test max model memory limit', + 'ml/data_frame_analytics_crud/Test put regression given dependent_variable is not defined', + 
'ml/data_frame_analytics_crud/Test put regression given negative lambda', + 'ml/data_frame_analytics_crud/Test put regression given negative gamma', + 'ml/data_frame_analytics_crud/Test put regression given eta less than 1e-3', + 'ml/data_frame_analytics_crud/Test put regression given eta greater than one', + 'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is zero', + 'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k', + 'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative', + 'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one', 'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given missing evaluation', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 56ea04793c3..520f7a30ece 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import java.io.IOException; import java.util.ArrayList; @@ -118,4 +119,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest assertThat(stats.get(0).getId(), equalTo(id)); assertThat(stats.get(0).getState(), equalTo(state)); } + + protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex, + @Nullable String resultsField, String dependentVariable) { + DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); + configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); + configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField)); + configBuilder.setAnalysis(new Regression(dependentVariable)); + return configBuilder.build(); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 3e4fd4f7003..f1b9f8edf14 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -21,9 +21,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.junit.After; +import java.util.Arrays; +import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.allOf; +import 
static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -362,4 +365,68 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); } + + public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { + String sourceIndex = "test-regression-with-numeric-feature-and-few-docs"; + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + List featureValues = Arrays.asList(1.0, 2.0, 3.0); + List dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0); + + for (int i = 0; i < 350; i++) { + Double field = featureValues.get(i % 3); + Double value = dependentVariableValues.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex); + if (i < 300) { + indexRequest.source("feature", field, "variable", value); + } else { + indexRequest.source("feature", field); + } + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_regression_with_numeric_feature_and_few_docs"; + DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex}, + sourceIndex + "-results", null, "variable"); + registerAnalytics(config); + putAnalytics(config); + + assertState(id, DataFrameAnalyticsState.STOPPED); + + startAnalytics(id); + waitUntilAnalyticsIsStopped(id); + + int resultsWithPrediction = 0; + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); + for (SearchHit hit : sourceData.getHits()) { + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); + assertThat(destDocGetResponse.isExists(), is(true)); + Map sourceDoc = hit.getSourceAsMap(); + Map destDoc = destDocGetResponse.getSource(); + for (String field : sourceDoc.keySet()) { + assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); + } + assertThat(destDoc.containsKey("ml"), is(true)); + + @SuppressWarnings("unchecked") + Map resultsObject = (Map) destDoc.get("ml"); + + if (resultsObject.containsKey("variable_prediction")) { + resultsWithPrediction++; + double featureValue = (double) destDoc.get("feature"); + double predictionValue = (double) resultsObject.get("variable_prediction"); + // it seems for this case values can be as far off as 2.0 + assertThat(predictionValue, closeTo(10 * featureValue, 2.0)); + } + } + assertThat(resultsWithPrediction, greaterThan(0)); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java index 7dafbb5f4dc..8c741e3c535 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java @@ -16,10 +16,12 @@ import org.elasticsearch.search.SearchHit; import java.io.IOException; import java.text.ParseException; import 
java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; /** * Represents a field to be extracted by the datafeed. @@ -37,11 +39,14 @@ public abstract class ExtractedField { /** The name of the field we extract */ protected final String name; + private final Set types; + private final ExtractionMethod extractionMethod; - protected ExtractedField(String alias, String name, ExtractionMethod extractionMethod) { + protected ExtractedField(String alias, String name, Set types, ExtractionMethod extractionMethod) { this.alias = Objects.requireNonNull(alias); this.name = Objects.requireNonNull(name); + this.types = Objects.requireNonNull(types); this.extractionMethod = Objects.requireNonNull(extractionMethod); } @@ -53,6 +58,10 @@ public abstract class ExtractedField { return name; } + public Set getTypes() { + return types; + } + public ExtractionMethod getExtractionMethod() { return extractionMethod; } @@ -65,32 +74,32 @@ public abstract class ExtractedField { return null; } - public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) { + public static ExtractedField newTimeField(String name, Set types, ExtractionMethod extractionMethod) { if (extractionMethod == ExtractionMethod.SOURCE) { throw new IllegalArgumentException("time field cannot be extracted from source"); } - return new TimeField(name, extractionMethod); + return new TimeField(name, types, extractionMethod); } public static ExtractedField newGeoShapeField(String alias, String name) { - return new GeoShapeField(alias, name); + return new GeoShapeField(alias, name, Collections.singleton("geo_shape")); } public static ExtractedField newGeoPointField(String alias, String name) { - return new GeoPointField(alias, name); + return new GeoPointField(alias, name, Collections.singleton("geo_point")); } - public static ExtractedField newField(String name, ExtractionMethod extractionMethod) { - return newField(name, name, extractionMethod); + public static ExtractedField newField(String name, Set types, ExtractionMethod extractionMethod) { + return newField(name, name, types, extractionMethod); } - public static ExtractedField newField(String alias, String name, ExtractionMethod extractionMethod) { + public static ExtractedField newField(String alias, String name, Set types, ExtractionMethod extractionMethod) { switch (extractionMethod) { case DOC_VALUE: case SCRIPT_FIELD: - return new FromFields(alias, name, extractionMethod); + return new FromFields(alias, name, types, extractionMethod); case SOURCE: - return new FromSource(alias, name); + return new FromSource(alias, name, types); default: throw new IllegalArgumentException("Invalid extraction method [" + extractionMethod + "]"); } @@ -98,7 +107,7 @@ public abstract class ExtractedField { public ExtractedField newFromSource() { if (supportsFromSource()) { - return new FromSource(alias, name); + return new FromSource(alias, name, types); } throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via [" + extractionMethod + "] and cannot be extracted from source"); @@ -106,8 +115,8 @@ public abstract class ExtractedField { private static class FromFields extends ExtractedField { - FromFields(String alias, String name, ExtractionMethod extractionMethod) { - super(alias, name, extractionMethod); + FromFields(String alias, String name, Set types, ExtractionMethod extractionMethod) { + super(alias, name, 
types, extractionMethod); } @Override @@ -129,8 +138,8 @@ public abstract class ExtractedField { private static class GeoShapeField extends FromSource { private static final WellKnownText wkt = new WellKnownText(true, new StandardValidator(true)); - GeoShapeField(String alias, String name) { - super(alias, name); + GeoShapeField(String alias, String name, Set types) { + super(alias, name, types); } @Override @@ -186,8 +195,8 @@ public abstract class ExtractedField { private static class GeoPointField extends FromFields { - GeoPointField(String alias, String name) { - super(alias, name, ExtractionMethod.DOC_VALUE); + GeoPointField(String alias, String name, Set types) { + super(alias, name, types, ExtractionMethod.DOC_VALUE); } @Override @@ -222,8 +231,8 @@ public abstract class ExtractedField { private static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; - TimeField(String name, ExtractionMethod extractionMethod) { - super(name, name, extractionMethod); + TimeField(String name, Set types, ExtractionMethod extractionMethod) { + super(name, name, types, extractionMethod); } @Override @@ -255,8 +264,8 @@ public abstract class ExtractedField { private String[] namePath; - FromSource(String alias, String name) { - super(alias, name, ExtractionMethod.SOURCE); + FromSource(String alias, String name, Set types) { + super(alias, name, types, ExtractionMethod.SOURCE); namePath = name.split("\\."); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java index 9495c5a2b40..6c90d2c7db2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFields.java @@ -47,15 +47,6 @@ public class ExtractedFields { return docValueFields; } - /** - * Returns a new instance which only contains fields matching the given extraction method - * @param method the extraction method to filter fields on - * @return a new instance which only contains fields matching the given extraction method - */ - public ExtractedFields filterFields(ExtractedField.ExtractionMethod method) { - return new ExtractedFields(filterFields(method, allFields)); - } - private static List filterFields(ExtractedField.ExtractionMethod method, List fields) { return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList()); } @@ -79,12 +70,13 @@ public class ExtractedFields { protected ExtractedField detect(String field) { String internalField = field; ExtractedField.ExtractionMethod method = ExtractedField.ExtractionMethod.SOURCE; + Set types = getTypes(field); if (scriptFields.contains(field)) { method = ExtractedField.ExtractionMethod.SCRIPT_FIELD; } else if (isAggregatable(field)) { method = ExtractedField.ExtractionMethod.DOC_VALUE; if (isFieldOfType(field, "date")) { - return ExtractedField.newTimeField(field, method); + return ExtractedField.newTimeField(field, types, method); } } else if (isFieldOfType(field, TEXT)) { String parentField = MlStrings.getParentField(field); @@ -107,7 +99,12 @@ public class ExtractedFields { return ExtractedField.newGeoShapeField(field, internalField); } - return ExtractedField.newField(field, internalField, method); + return ExtractedField.newField(field, internalField, types, method); + } + + private Set getTypes(String field) { + Map fieldCaps 
= fieldsCapabilities.getField(field); + return fieldCaps == null ? Collections.emptySet() : fieldCaps.keySet(); } protected boolean isAggregatable(String field) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java index cf87671bf33..1067ef63007 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFields.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Job; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; @@ -55,12 +56,20 @@ public class TimeBasedExtractedFields extends ExtractedFields { if (scriptFields.contains(timeField) == false && extractionMethodDetector.isAggregatable(timeField) == false) { throw new IllegalArgumentException("cannot retrieve time field [" + timeField + "] because it is not aggregatable"); } - ExtractedField timeExtractedField = ExtractedField.newTimeField(timeField, scriptFields.contains(timeField) ? - ExtractedField.ExtractionMethod.SCRIPT_FIELD : ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeExtractedField = extractedTimeField(timeField, scriptFields, fieldsCapabilities); List remainingFields = job.allInputFields().stream().filter(f -> !f.equals(timeField)).collect(Collectors.toList()); List allExtractedFields = new ArrayList<>(remainingFields.size() + 1); allExtractedFields.add(timeExtractedField); remainingFields.stream().forEach(field -> allExtractedFields.add(extractionMethodDetector.detect(field))); return new TimeBasedExtractedFields(timeExtractedField, allExtractedFields); } + + private static ExtractedField extractedTimeField(String timeField, Set scriptFields, + FieldCapabilitiesResponse fieldCapabilities) { + if (scriptFields.contains(timeField)) { + return ExtractedField.newTimeField(timeField, Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD); + } + return ExtractedField.newTimeField(timeField, fieldCapabilities.getField(timeField).keySet(), + ExtractedField.ExtractionMethod.DOC_VALUE); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index fa18f3bb25b..d9f1aa994d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -29,11 +29,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -179,7 +181,7 @@ public class DataFrameDataExtractor { for (int i = 0; i < extractedValues.length; ++i) { ExtractedField field = context.extractedFields.getAllFields().get(i); Object[] values = field.value(hit); - if 
(values.length == 1 && values[0] instanceof Number) { + if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { extractedValues[i] = Objects.toString(values[0]); } else { extractedValues = null; @@ -233,6 +235,17 @@ public class DataFrameDataExtractor { return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size()); } + public Set getCategoricalFields() { + Set categoricalFields = new HashSet<>(); + for (ExtractedField extractedField : context.extractedFields.getAllFields()) { + String fieldName = extractedField.getName(); + if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) { + categoricalFields.add(fieldName); + } + } + return categoricalFields; + } + public static class DataSummary { public final long rows; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index d58eaebe353..3ff8c8a4923 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.ml.dataframe.extractor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; @@ -20,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex; import java.util.ArrayList; import java.util.Arrays; @@ -35,24 +38,24 @@ import java.util.stream.Stream; public class ExtractedFieldsDetector { + private static final Logger LOGGER = LogManager.getLogger(ExtractedFieldsDetector.class); + /** * Fields to ignore. These are mostly internal meta fields. 
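 * The internal document-id copy that analytics writes into its destination index ({@code DataFrameAnalyticsIndex.ID_COPY}) is likewise ignored so it is never extracted as a feature.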
*/ private static final List IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", - "_source", "_type", "_uid", "_version", "_feature", "_ignored"); + "_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY); - /** - * The types supported by data frames - */ - private static final Set COMPATIBLE_FIELD_TYPES; + public static final Set CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip"))); + + private static final Set NUMERICAL_TYPES; static { - Set compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values()) + Set numericalTypes = Stream.of(NumberFieldMapper.NumberType.values()) .map(NumberFieldMapper.NumberType::typeName) .collect(Collectors.toSet()); - compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module - - COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes); + numericalTypes.add("scaled_float"); + NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes); } private final String[] index; @@ -79,16 +82,18 @@ public class ExtractedFieldsDetector { // Ignore fields under the results object fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + ".")); + includeAndExcludeFields(fields); removeFieldsWithIncompatibleTypes(fields); - includeAndExcludeFields(fields, index); + checkRequiredFieldsArePresent(fields); + + if (fields.isEmpty()) { + throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index)); + } + List sortedFields = new ArrayList<>(fields); // We sort the fields to ensure the checksum for each document is deterministic Collections.sort(sortedFields); - ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) - .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); - if (extractedFields.getAllFields().isEmpty()) { - throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index)); - } + ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse); if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { extractedFields = fetchFromSourceIfSupported(extractedFields); if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { @@ -120,13 +125,25 @@ public class ExtractedFieldsDetector { while (fieldsIterator.hasNext()) { String field = fieldsIterator.next(); Map fieldCaps = fieldCapabilitiesResponse.getField(field); - if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) { + if (fieldCaps == null) { + LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field); fieldsIterator.remove(); + } else { + Set fieldTypes = fieldCaps.keySet(); + if (NUMERICAL_TYPES.containsAll(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field); + } else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field); + } else { + LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}", + config.getId(), field, fieldTypes); + fieldsIterator.remove(); + } } } } - private void includeAndExcludeFields(Set fields, String[] index) { + private void 
includeAndExcludeFields(Set fields) { FetchSourceContext analyzedFields = config.getAnalyzedFields(); if (analyzedFields == null) { return; @@ -159,6 +176,16 @@ public class ExtractedFieldsDetector { } } + private void checkRequiredFieldsArePresent(Set fields) { + List missingFields = config.getAnalysis().getRequiredFields() + .stream() + .filter(f -> fields.contains(f) == false) + .collect(Collectors.toList()); + if (missingFields.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields); + } + } + private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) { List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); for (ExtractedField field : extractedFields.getDocValueFields()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 226498376bb..70a2e213fb6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import java.io.IOException; import java.util.Objects; +import java.util.Set; public class AnalyticsProcessConfig implements ToXContentObject { @@ -21,21 +22,24 @@ public class AnalyticsProcessConfig implements ToXContentObject { private static final String THREADS = "threads"; private static final String ANALYSIS = "analysis"; private static final String RESULTS_FIELD = "results_field"; + private static final String CATEGORICAL_FIELDS = "categorical_fields"; private final long rows; private final int cols; private final ByteSizeValue memoryLimit; private final int threads; - private final DataFrameAnalysis analysis; private final String resultsField; + private final Set categoricalFields; + private final DataFrameAnalysis analysis; public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField, - DataFrameAnalysis analysis) { + Set categoricalFields, DataFrameAnalysis analysis) { this.rows = rows; this.cols = cols; this.memoryLimit = Objects.requireNonNull(memoryLimit); this.threads = threads; this.resultsField = Objects.requireNonNull(resultsField); + this.categoricalFields = Objects.requireNonNull(categoricalFields); this.analysis = Objects.requireNonNull(analysis); } @@ -51,6 +55,7 @@ public class AnalyticsProcessConfig implements ToXContentObject { builder.field(MEMORY_LIMIT, memoryLimit.getBytes()); builder.field(THREADS, threads); builder.field(RESULTS_FIELD, resultsField); + builder.field(CATEGORICAL_FIELDS, categoricalFields); builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis)); builder.endObject(); return builder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index cb000a15496..f04ba577be4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.util.List; import java.util.Objects; 
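For reference, a minimal sketch of calling the widened AnalyticsProcessConfig constructor shown above once the categorical fields have been collected; the concrete numbers, the "1gb" limit and the "variable_class" field name are illustrative assumptions, not values taken from this patch:

    // Illustrative only; mirrors what AnalyticsProcessManager#createProcessConfig below ends up doing.
    static AnalyticsProcessConfig exampleProcessConfig() {
        Set<String> categoricalFields = Collections.singleton("variable_class");       // e.g. a keyword field (assumed)
        return new AnalyticsProcessConfig(
            1000,                                                            // rows in the data frame
            5,                                                               // cols, i.e. number of extracted fields
            ByteSizeValue.parseBytesSizeValue("1gb", "model_memory_limit"),  // memory limit handed to the native process
            1,                                                               // threads
            "ml",                                                            // results_field
            categoricalFields,                                               // new argument added by this change
            new Regression("variable"));                                     // the analysis to run
    }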
import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; @@ -283,8 +284,9 @@ public class AnalyticsProcessManager { private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + Set categoricalFields = dataExtractor.getCategoricalFields(); AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis()); + config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis()); return processConfig; } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java index 6969c97be0a..87f86a33f99 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import java.util.Arrays; +import java.util.Collections; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.startsWith; @@ -19,46 +20,51 @@ public class ExtractedFieldTests extends ESTestCase { public void testValueGivenDocValue() { SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build(); - ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField single = ExtractedField.newField("single", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(single.value(hit), equalTo(new String[] { "bar" })); - ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField array = ExtractedField.newField("array", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField missing = ExtractedField.newField("missing",Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(missing.value(hit), equalTo(new Object[0])); } public void testValueGivenScriptField() { SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build(); - ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField single = ExtractedField.newField("single",Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); assertThat(single.value(hit), equalTo(new String[] { "bar" })); - ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField array = ExtractedField.newField("array", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD); assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); - ExtractedField 
missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField missing = ExtractedField.newField("missing", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD); assertThat(missing.value(hit), equalTo(new Object[0])); } public void testValueGivenSource() { SearchHit hit = new SearchHitBuilder(42).setSource("{\"single\":\"bar\",\"array\":[\"a\",\"b\"]}").build(); - ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField single = ExtractedField.newField("single", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(single.value(hit), equalTo(new String[] { "bar" })); - ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField array = ExtractedField.newField("array", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(missing.value(hit), equalTo(new Object[0])); } public void testValueGivenNestedSource() { SearchHit hit = new SearchHitBuilder(42).setSource("{\"level_1\":{\"level_2\":{\"foo\":\"bar\"}}}").build(); - ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); assertThat(nested.value(hit), equalTo(new String[] { "bar" })); } @@ -91,49 +97,54 @@ public class ExtractedFieldTests extends ESTestCase { } public void testValueGivenSourceAndHitWithNoSource() { - ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(missing.value(new SearchHitBuilder(3).build()), equalTo(new Object[0])); } public void testValueGivenMismatchingMethod() { SearchHit hit = new SearchHitBuilder(42).addField("a", 1).setSource("{\"b\":2}").build(); - ExtractedField invalidA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField invalidA = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(invalidA.value(hit), equalTo(new Object[0])); - ExtractedField validA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField validA = ExtractedField.newField("a", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(validA.value(hit), equalTo(new Integer[] { 1 })); - ExtractedField invalidB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField invalidB = ExtractedField.newField("b", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(invalidB.value(hit), equalTo(new Object[0])); - ExtractedField validB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField validB = ExtractedField.newField("b", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(validB.value(hit), equalTo(new Integer[] { 2 
})); } public void testValueGivenEmptyHit() { SearchHit hit = new SearchHitBuilder(42).build(); - ExtractedField docValue = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField docValue = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE); assertThat(docValue.value(hit), equalTo(new Object[0])); - ExtractedField sourceField = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField sourceField = ExtractedField.newField("b", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(sourceField.value(hit), equalTo(new Object[0])); } public void testNewTimeFieldGivenSource() { - expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.SOURCE)); + expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.SOURCE)); } public void testValueGivenStringTimeField() { final long millis = randomLong(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", Long.toString(millis)).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(timeField.value(hit), equalTo(new Object[] { millis })); } public void testValueGivenLongTimeField() { final long millis = randomLong(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(timeField.value(hit), equalTo(new Object[] { millis })); } @@ -141,13 +152,15 @@ public class ExtractedFieldTests extends ESTestCase { // Prior to 6.x, timestamps were simply `long` milliseconds-past-the-epoch values final long millis = randomLong(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(timeField.value(hit), equalTo(new Object[] { millis })); } public void testValueGivenUnknownFormatTimeField() { final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", new Object()).build(); - final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(expectThrows(IllegalStateException.class, () -> timeField.value(hit)).getMessage(), startsWith("Unexpected value for a time field")); } @@ -155,14 +168,15 @@ public class ExtractedFieldTests extends ESTestCase { public void testAliasVersusName() { SearchHit hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build(); - ExtractedField field = ExtractedField.newField("a", "a", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField field = 
ExtractedField.newField("a", "a", Collections.singleton("int"), + ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(field.getAlias(), equalTo("a")); assertThat(field.getName(), equalTo("a")); assertThat(field.value(hit), equalTo(new Integer[] { 1 })); hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build(); - field = ExtractedField.newField("a", "b", ExtractedField.ExtractionMethod.DOC_VALUE); + field = ExtractedField.newField("a", "b", Collections.singleton("int"), ExtractedField.ExtractionMethod.DOC_VALUE); assertThat(field.getAlias(), equalTo("a")); assertThat(field.getName(), equalTo("b")); assertThat(field.value(hit), equalTo(new Integer[] { 2 })); @@ -170,11 +184,11 @@ public class ExtractedFieldTests extends ESTestCase { public void testGetDocValueFormat() { for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) { - assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(null)); + assertThat(ExtractedField.newField("f", Collections.emptySet(), method).getDocValueFormat(), equalTo(null)); } - assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), - equalTo("epoch_millis")); - assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), - equalTo("epoch_millis")); + assertThat(ExtractedField.newTimeField("doc_value_time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), equalTo("epoch_millis")); + assertThat(ExtractedField.newTimeField("source_time", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), equalTo("epoch_millis")); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java index db25f820dbb..8dd81b47eb8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedFieldsTests.java @@ -27,12 +27,18 @@ import static org.mockito.Mockito.when; public class ExtractedFieldsTests extends ESTestCase { public void testAllTypesOfFields() { - ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE); - ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("ip"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField scriptField2 = ExtractedField.newField("scripted2", 
Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); + ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java index 6e7a3740e0a..652eb068783 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/TimeBasedExtractedFieldsTests.java @@ -29,7 +29,8 @@ import static org.mockito.Mockito.when; public class TimeBasedExtractedFieldsTests extends ESTestCase { - private ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + private ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); public void testInvalidConstruction() { expectThrows(IllegalArgumentException.class, () -> new TimeBasedExtractedFields(timeField, Collections.emptyList())); @@ -46,12 +47,18 @@ public class TimeBasedExtractedFieldsTests extends ESTestCase { } public void testAllTypesOfFields() { - ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD); - ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE); - ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE); + ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("float"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(), + ExtractedField.ExtractionMethod.SCRIPT_FIELD); + ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); + ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"), + ExtractedField.ExtractionMethod.SOURCE); TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java index c383cf20b18..bdbe81a66a6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java @@ -135,9 +135,11 @@ public class ScrollDataExtractorTests extends ESTestCase { capturedSearchRequests = new ArrayList<>(); capturedContinueScrollIds = new ArrayList<>(); jobId = "test-job"; - ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); extractedFields = new TimeBasedExtractedFields(timeField, - Arrays.asList(timeField, ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE))); + Arrays.asList(timeField, ExtractedField.newField("field_1", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE))); indices = Arrays.asList("index-1", "index-2"); query = QueryBuilders.matchAllQuery(); scriptFields = Collections.emptyList(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java index 41a74814461..d2befb407ae 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/SearchHitToJsonProcessorTests.java @@ -16,16 +16,21 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Collections; import static org.hamcrest.Matchers.equalTo; public class SearchHitToJsonProcessorTests extends ESTestCase { public void testProcessGivenSingleHit() throws IOException { - ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField)); @@ -41,10 +46,14 @@ public class SearchHitToJsonProcessorTests extends ESTestCase { } public void testProcessGivenMultipleHits() throws IOException { - ExtractedField timeField = ExtractedField.newField("time", 
ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); - ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); + ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"), + ExtractedField.ExtractionMethod.DOC_VALUE); TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index ffd53e5576f..b456de7b637 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -71,8 +71,8 @@ public class DataFrameDataExtractorTests extends ESTestCase { indices = Arrays.asList("index-1", "index-2"); query = QueryBuilders.matchAllQuery(); extractedFields = new ExtractedFields(Arrays.asList( - ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), - ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.DOC_VALUE))); + ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_2", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE))); scrollSize = 1000; headers = Collections.emptyMap(); @@ -288,8 +288,8 @@ public class DataFrameDataExtractorTests extends ESTestCase { public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException { extractedFields = new ExtractedFields(Arrays.asList( - ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), - ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.SOURCE))); + ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_2", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE))); TestExtractor dataExtractor = createExtractor(false); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 1345a1fe128..5f781538bec 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; @@ -38,11 +39,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { private static final String RESULTS_FIELD = "ml"; public void testDetect_GivenFloatField() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); @@ -52,12 +53,12 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { } public void testDetect_GivenNumericFieldWithMultipleTypes() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); @@ -67,36 +68,36 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { } public void testDetect_GivenNonNumericField() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_keyword", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetect_GivenFieldWithNumericAndNonNumericTypes() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericTypes() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("indecisive_field", "float", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); 
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetect_GivenMultipleFields() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + public void testDetect_GivenOutlierDetectionAndMultipleFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float") .addAggregatableField("some_long", "long") .addAggregatableField("some_keyword", "keyword") .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); @@ -107,12 +108,46 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); } + public void testDetect_GivenRegressionAndMultipleFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .addAggregatableField("foo", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(4)); + assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), + contains("foo", "some_float", "some_keyword", "some_long")); + assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()), + contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); + } + + public void testDetect_GivenRegressionAndRequiredFieldMissing() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("required fields [foo] are missing")); + } + public void testDetect_GivenIgnoredField() { - FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("_id", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); @@ -134,7 +169,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FieldCapabilitiesResponse fieldCapabilities = 
mockFieldCapsResponseBuilder.build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -151,7 +186,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected")); @@ -166,7 +201,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -182,7 +217,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -199,7 +234,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " + @@ -215,7 +250,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 100, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -232,7 +267,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector 
extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 4, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -251,7 +286,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 3, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -270,7 +305,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, 2, fieldCapabilities); + SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -280,11 +315,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.ExtractionMethod.SOURCE))); } - private static DataFrameAnalyticsConfig buildAnalyticsConfig() { - return buildAnalyticsConfig(null); + private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() { + return buildOutlierDetectionConfig(null); } - private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analyzedFields) { + private static DataFrameAnalyticsConfig buildOutlierDetectionConfig(FetchSourceContext analyzedFields) { return new DataFrameAnalyticsConfig.Builder("foo") .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) @@ -293,6 +328,19 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .build(); } + private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) { + return buildRegressionConfig(dependentVariable, null); + } + + private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, FetchSourceContext analyzedFields) { + return new DataFrameAnalyticsConfig.Builder("foo") + .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) + .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) + .setAnalyzedFields(analyzedFields) + .setAnalysis(new Regression(dependentVariable)) + .build(); + } + private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 772c48e5474..253790878c5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -607,7 +607,11 @@ setup: "dest": { "index": "index-bar_dest" }, - "analysis": {"outlier_detection":{}} + "analysis": { + "regression":{ + "dependent_variable": "to_predict" + } + } } - match: { id: "bar" } @@ -768,7 +772,11 @@ setup: "dest": { "index": "index-bar_dest" }, - "analysis": {"outlier_detection":{}} + 
"analysis": { + "regression":{ + "dependent_variable": "to_predict" + } + } } - match: { id: "bar" } @@ -930,3 +938,247 @@ setup: xpack.ml.max_model_memory_limit: null - match: {transient: {}} +--- +"Test put regression given dependent_variable is not defined": + + - do: + catch: /parse_exception/ + ml.put_data_frame_analytics: + id: "regression-without-dependent-variable" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": {} + } + } + +--- +"Test put regression given negative lambda": + + - do: + catch: /\[lambda\] must be a non-negative double/ + ml.put_data_frame_analytics: + id: "regression-negative-lambda" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "lambda": -1.0 + } + } + } + +--- +"Test put regression given negative gamma": + + - do: + catch: /\[gamma\] must be a non-negative double/ + ml.put_data_frame_analytics: + id: "regression-negative-gamma" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "gamma": -1.0 + } + } + } + +--- +"Test put regression given eta less than 1e-3": + + - do: + catch: /\[eta\] must be a double in \[0.001, 1\]/ + ml.put_data_frame_analytics: + id: "regression-eta-greater-less-than-valid" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "eta": 0.0009 + } + } + } + +--- +"Test put regression given eta greater than one": + + - do: + catch: /\[eta\] must be a double in \[0.001, 1\]/ + ml.put_data_frame_analytics: + id: "regression-eta-greater-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "eta": 1.00001 + } + } + } + +--- +"Test put regression given maximum_number_trees is zero": + + - do: + catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/ + ml.put_data_frame_analytics: + id: "regression-maximum-number-trees-is-zero" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "maximum_number_trees": 0 + } + } + } + +--- +"Test put regression given maximum_number_trees is greater than 2k": + + - do: + catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/ + ml.put_data_frame_analytics: + id: "regression-maximum-number-trees-greater-than-2k" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "maximum_number_trees": 2001 + } + } + } + +--- +"Test put regression given feature_bag_fraction is negative": + + - do: + catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/ + ml.put_data_frame_analytics: + id: "regression-feature-bag-fraction-is-negative" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "feature_bag_fraction": -0.0001 + } + } + } + +--- +"Test put regression given feature_bag_fraction is greater than one": + + - do: + catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/ + ml.put_data_frame_analytics: + id: 
"regression-feature-bag-fraction-is-greater-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "feature_bag_fraction": 1.0001 + } + } + } + +--- +"Test put regression given valid": + + - do: + ml.put_data_frame_analytics: + id: "valid-regression" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "regression": { + "dependent_variable": "foo", + "lambda": 3.14, + "gamma": 0.42, + "eta": 0.5, + "maximum_number_trees": 400, + "feature_bag_fraction": 0.3 + } + } + } + - match: { id: "valid-regression" } + - match: { source.index: ["index-source"] } + - match: { dest.index: "index-dest" } + - match: { analysis: { + "regression":{ + "dependent_variable": "foo", + "lambda": 3.14, + "gamma": 0.42, + "eta": 0.5, + "maximum_number_trees": 400, + "feature_bag_fraction": 0.3 + } + }} + - is_true: create_time + - is_true: version diff --git a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java index 4bbef475f99..8b7f9d06bf5 100644 --- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java @@ -11,15 +11,22 @@ import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.WarningFailureException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.Detector; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.test.rest.XPackRestTestConstants; import org.elasticsearch.xpack.test.rest.XPackRestTestHelper; import org.junit.Before; @@ -28,12 +35,12 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClusterRestartTestCase { @@ -41,14 +48,23 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends 
AbstractFullClust private static final String OLD_CLUSTER_JOB_ID = "ml-config-mappings-old-cluster-job"; private static final String NEW_CLUSTER_JOB_ID = "ml-config-mappings-new-cluster-job"; - private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = - mapOf( - "properties", mapOf( - "outlier_detection", mapOf( - "properties", mapOf( - "method", mapOf("type", "keyword"), - "n_neighbors", mapOf("type", "integer"), - "feature_influence_threshold", mapOf("type", "double"))))); + private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = getDataFrameAnalysisMappings(); + + @SuppressWarnings("unchecked") + private static Map<String, Object> getDataFrameAnalysisMappings() { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + ElasticsearchMappings.addDataFrameAnalyticsFields(builder); + builder.endObject(); + + Map<String, Object> asMap = builder.generator().contentType().xContent().createParser( + NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput()).map(); + return (Map<String, Object>) asMap.get(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName()); + } catch (IOException e) { + fail("Failed to initialize expected data frame analysis mappings"); + } + return null; + } @Override protected Settings restClientSettings() { @@ -71,8 +87,8 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust // trigger .ml-config index creation createAnomalyDetectorJob(OLD_CLUSTER_JOB_ID); if (getOldClusterVersion().onOrAfter(Version.V_7_3_0)) { - // .ml-config has correct mappings from the start - assertThat(mappingsForDataFrameAnalysis(), is(equalTo(EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS))); + // .ml-config has mappings for analytics as the feature was introduced in 7.3.0 + assertThat(mappingsForDataFrameAnalysis(), is(notNullValue())); } else { // .ml-config does not yet have correct mappings, it will need an update after cluster is upgraded assertThat(mappingsForDataFrameAnalysis(), is(nullValue())); @@ -125,18 +141,4 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust mappings = (Map<String, Object>) XContentMapValues.extractValue(mappings, "properties", "analysis"); return mappings; } - - private static <K, V> Map<K, V> mapOf(K k1, V v1) { - Map<K, V> map = new HashMap<>(); - map.put(k1, v1); - return map; - } - - private static <K, V> Map<K, V> mapOf(K k1, V v1, K k2, V v2, K k3, V v3) { - Map<K, V> map = new HashMap<>(); - map.put(k1, v1); - map.put(k2, v2); - map.put(k3, v3); - return map; - } }
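
For readers cross-checking the ExtractedFieldsDetectorTests changes earlier in this patch: the outlier-detection cases reject keyword-only and mixed float/keyword fields, while testDetect_GivenRegressionAndMultipleFields keeps "some_keyword" alongside the numeric fields. A plausible reading is that field selection depends on whether the analysis accepts categorical values; the standalone sketch below illustrates that idea only, under assumed type lists, and is not the ExtractedFieldsDetector implementation (all class and method names here are hypothetical).

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical, simplified illustration of type-based field selection; not the
// ExtractedFieldsDetector code. Assumption: numeric fields are always retained, and
// keyword-like fields are retained only when the analysis accepts categorical values.
final class FieldTypeFiltering {

    private static final Set<String> NUMERIC_TYPES =
        Set.of("long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float");
    private static final Set<String> CATEGORICAL_TYPES = Set.of("keyword", "text", "ip");

    static List<String> selectFields(Map<String, String> fieldTypes, boolean analysisSupportsCategorical) {
        return fieldTypes.entrySet().stream()
            .filter(e -> NUMERIC_TYPES.contains(e.getValue())
                || (analysisSupportsCategorical && CATEGORICAL_TYPES.contains(e.getValue())))
            .map(Map.Entry::getKey)
            .sorted()
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<String, String> fields = Map.of(
            "some_float", "float", "some_long", "long", "some_keyword", "keyword", "foo", "keyword");
        // Without categorical support (the outlier-detection cases above), keyword fields drop out.
        System.out.println(selectFields(fields, false)); // [some_float, some_long]
        // With categorical support, all four fields survive, matching the order asserted in
        // testDetect_GivenRegressionAndMultipleFields: [foo, some_float, some_keyword, some_long]
        System.out.println(selectFields(fields, true));
    }
}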
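
testDetect_GivenRegressionAndRequiredFieldMissing expects the message "required fields [foo] are missing" when the regression dependent variable is absent from the source index. A minimal standalone sketch of that kind of presence check (hypothetical names, not the detector code from this change) is:

import java.util.Set;
import java.util.TreeSet;

// Hypothetical, simplified stand-in for the required-field check the test above exercises;
// not the ExtractedFieldsDetector implementation.
final class RequiredFieldsCheck {

    // Fails when any field required by the analysis (for regression, the dependent_variable)
    // is not among the fields detected in the source index.
    static void checkRequiredFieldsArePresent(Set<String> requiredFields, Set<String> detectedFields) {
        Set<String> missing = new TreeSet<>(requiredFields);
        missing.removeAll(detectedFields);
        if (missing.isEmpty() == false) {
            throw new IllegalArgumentException("required fields " + missing + " are missing");
        }
    }

    public static void main(String[] args) {
        Set<String> detected = Set.of("some_float", "some_long", "some_keyword");
        try {
            // Mirrors testDetect_GivenRegressionAndRequiredFieldMissing: "foo" is required but absent.
            checkRequiredFieldsArePresent(Set.of("foo"), detected);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints: required fields [foo] are missing
        }
    }
}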
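
The new data_frame_analytics_crud.yml cases pin down the accepted ranges for the regression hyperparameters: lambda and gamma must be non-negative, eta must lie in [0.001, 1], maximum_number_trees in [1, 2000], and feature_bag_fraction in (0, 1]. As a compact reference, a standalone validator with the same bounds (again a sketch with hypothetical names, not the actual Regression parser) could look like:

// Hypothetical, standalone validator mirroring the bounds asserted by the YAML tests above;
// the real checks live in the Regression analysis itself, not in this class.
final class RegressionHyperparameterBounds {

    static void validate(Double lambda, Double gamma, Double eta,
                         Integer maximumNumberTrees, Double featureBagFraction) {
        if (lambda != null && lambda < 0) {
            throw new IllegalArgumentException("[lambda] must be a non-negative double");
        }
        if (gamma != null && gamma < 0) {
            throw new IllegalArgumentException("[gamma] must be a non-negative double");
        }
        if (eta != null && (eta < 0.001 || eta > 1)) {
            throw new IllegalArgumentException("[eta] must be a double in [0.001, 1]");
        }
        if (maximumNumberTrees != null && (maximumNumberTrees < 1 || maximumNumberTrees > 2000)) {
            throw new IllegalArgumentException("[maximum_number_trees] must be an integer in [1, 2000]");
        }
        if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1)) {
            throw new IllegalArgumentException("[feature_bag_fraction] must be a double in (0, 1]");
        }
    }

    public static void main(String[] args) {
        // The "valid-regression" YAML test uses these values and expects the config to be accepted.
        validate(3.14, 0.42, 0.5, 400, 0.3);
        System.out.println("valid-regression hyperparameters pass all bounds checks");
    }
}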