[7.x][ML] Add regression analysis to DF analytics (#45292) (#45388)

This commit adds a first draft of a regression analysis
to data frame analytics. It is likely that the exact
syntax will change.

This commit adds the new analysis type and its parameters as
well as appropriate validation. It also modifies the extractor
and the fields detector to be able to handle categorical fields
as regression analysis supports them.
This commit is contained in:
Dimitris Athanasiou 2019-08-09 19:31:13 +03:00 committed by GitHub
parent d1ed9bdbfd
commit 27497ff75f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1026 additions and 164 deletions

View File

@ -147,6 +147,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -454,6 +455,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
MachineLearningFeatureSetUsage::new), MachineLearningFeatureSetUsage::new),
// ML - Data frame analytics // ML - Data frame analytics
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
// ML - Data frame evaluation // ML - Data frame evaluation
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new), BinarySoftClassification::new),

View File

@ -9,8 +9,22 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import java.util.Map; import java.util.Map;
import java.util.Set;
public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
/**
* @return The analysis parameters as a map
*/
Map<String, Object> getParams(); Map<String, Object> getParams();
/**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
*/
boolean supportsCategoricalFields();
/**
* @return The set of fields that analyzed documents must have for the analysis to operate
*/
Set<String> getRequiredFields();
} }

View File

@ -22,6 +22,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
boolean ignoreUnknownFields = (boolean) c; boolean ignoreUnknownFields = (boolean) c;
return OutlierDetection.fromXContent(p, ignoreUnknownFields); return OutlierDetection.fromXContent(p, ignoreUnknownFields);
})); }));
namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
boolean ignoreUnknownFields = (boolean) c;
return Regression.fromXContent(p, ignoreUnknownFields);
}));
return namedXContent; return namedXContent;
} }
@ -31,6 +35,8 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(),
OutlierDetection::new)); OutlierDetection::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(),
Regression::new));
return namedWriteables; return namedWriteables;
} }

View File

@ -16,10 +16,12 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
public class OutlierDetection implements DataFrameAnalysis { public class OutlierDetection implements DataFrameAnalysis {
@ -152,6 +154,16 @@ public class OutlierDetection implements DataFrameAnalysis {
return params; return params;
} }
@Override
public boolean supportsCategoricalFields() {
    // Outlier detection works on numeric feature vectors only, so
    // categorical fields (text, keyword, ip) are not supported.
    return false;
}
@Override
public Set<String> getRequiredFields() {
    // Unlike supervised analyses, outlier detection has no mandatory
    // per-document field (no dependent variable).
    return Collections.emptySet();
}
public enum Method { public enum Method {
LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;

View File

@ -0,0 +1,205 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
 * A {@link DataFrameAnalysis} that trains a regression model predicting the value
 * of the {@code dependent_variable} field from the other (feature) fields of the
 * analyzed documents. All hyperparameters except the dependent variable are
 * optional; when {@code null} they are omitted from both the x-content and the
 * params map, leaving the choice to the analytics process.
 */
public class Regression implements DataFrameAnalysis {

    public static final ParseField NAME = new ParseField("regression");

    public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
    // Hyperparameters below are presumably boosted-tree training knobs; their exact
    // semantics are defined by the native analytics process, not visible here.
    public static final ParseField LAMBDA = new ParseField("lambda");
    public static final ParseField GAMMA = new ParseField("gamma");
    public static final ParseField ETA = new ParseField("eta");
    public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
    public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");

    // Lenient parser tolerates unknown fields (reading back persisted config);
    // strict parser rejects them (validating user-supplied requests).
    private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);

    private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
        // Constructor-arg indices (a[0]..a[6]) must stay in sync with the declare* calls below.
        ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
            a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6]));
        parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
        return parser;
    }

    /**
     * Parses a {@code Regression} from x-content.
     *
     * @param ignoreUnknownFields whether unknown fields are tolerated (lenient) or rejected (strict)
     */
    public static Regression fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
    }

    private final String dependentVariable;
    private final Double lambda;
    private final Double gamma;
    private final Double eta;
    private final Integer maximumNumberTrees;
    private final Double featureBagFraction;
    private final String predictionFieldName;

    /**
     * @param dependentVariable   name of the field to predict; required
     * @param lambda              optional; must be {@code >= 0}
     * @param gamma               optional; must be {@code >= 0}
     * @param eta                 optional; must be in {@code [0.001, 1]}
     * @param maximumNumberTrees  optional; must be in {@code [1, 2000]}
     * @param featureBagFraction  optional; must be in {@code (0, 1]}
     * @param predictionFieldName optional name for the result field in the destination index
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if any value is out of range
     */
    public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                      @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) {
        this.dependentVariable = Objects.requireNonNull(dependentVariable);

        if (lambda != null && lambda < 0) {
            throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
        }
        this.lambda = lambda;

        if (gamma != null && gamma < 0) {
            throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
        }
        this.gamma = gamma;

        if (eta != null && (eta < 0.001 || eta > 1)) {
            throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
        }
        this.eta = eta;

        if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
        }
        this.maximumNumberTrees = maximumNumberTrees;

        if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
            throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
        }
        this.featureBagFraction = featureBagFraction;

        this.predictionFieldName = predictionFieldName;
    }

    /** Convenience constructor using defaults for all optional hyperparameters. */
    public Regression(String dependentVariable) {
        this(dependentVariable, null, null, null, null, null, null);
    }

    // Wire deserialization; field order must mirror writeTo exactly.
    public Regression(StreamInput in) throws IOException {
        dependentVariable = in.readString();
        lambda = in.readOptionalDouble();
        gamma = in.readOptionalDouble();
        eta = in.readOptionalDouble();
        maximumNumberTrees = in.readOptionalVInt();
        featureBagFraction = in.readOptionalDouble();
        predictionFieldName = in.readOptionalString();
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    // Wire serialization; field order must mirror the StreamInput constructor exactly.
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(dependentVariable);
        out.writeOptionalDouble(lambda);
        out.writeOptionalDouble(gamma);
        out.writeOptionalDouble(eta);
        out.writeOptionalVInt(maximumNumberTrees);
        out.writeOptionalDouble(featureBagFraction);
        out.writeOptionalString(predictionFieldName);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        // Optional parameters are omitted entirely when unset rather than written as null.
        if (lambda != null) {
            builder.field(LAMBDA.getPreferredName(), lambda);
        }
        if (gamma != null) {
            builder.field(GAMMA.getPreferredName(), gamma);
        }
        if (eta != null) {
            builder.field(ETA.getPreferredName(), eta);
        }
        if (maximumNumberTrees != null) {
            builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
        }
        if (featureBagFraction != null) {
            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        }
        if (predictionFieldName != null) {
            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        builder.endObject();
        return builder;
    }

    // Same keys/omission rules as toXContent, but as a plain map (analysis parameters).
    @Override
    public Map<String, Object> getParams() {
        Map<String, Object> params = new HashMap<>();
        params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        if (lambda != null) {
            params.put(LAMBDA.getPreferredName(), lambda);
        }
        if (gamma != null) {
            params.put(GAMMA.getPreferredName(), gamma);
        }
        if (eta != null) {
            params.put(ETA.getPreferredName(), eta);
        }
        if (maximumNumberTrees != null) {
            params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
        }
        if (featureBagFraction != null) {
            params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        }
        if (predictionFieldName != null) {
            params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        return params;
    }

    @Override
    public boolean supportsCategoricalFields() {
        // Regression accepts categorical features (text, keyword, ip).
        return true;
    }

    @Override
    public Set<String> getRequiredFields() {
        // Analyzed documents must contain the dependent variable for training.
        return Collections.singleton(dependentVariable);
    }

    @Override
    public int hashCode() {
        return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Regression that = (Regression) o;
        return Objects.equals(dependentVariable, that.dependentVariable)
            && Objects.equals(lambda, that.lambda)
            && Objects.equals(gamma, that.gamma)
            && Objects.equals(eta, that.eta)
            && Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
            && Objects.equals(featureBagFraction, that.featureBagFraction)
            && Objects.equals(predictionFieldName, that.predictionFieldName);
    }
}

View File

@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@ -443,6 +444,31 @@ public class ElasticsearchMappings {
.endObject() .endObject()
.endObject() .endObject()
.endObject() .endObject()
.startObject(Regression.NAME.getPreferredName())
.startObject(PROPERTIES)
.startObject(Regression.DEPENDENT_VARIABLE.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(Regression.LAMBDA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.GAMMA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.ETA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName())
.field(TYPE, INTEGER)
.endObject()
.startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.endObject()
.endObject()
.endObject() .endObject()
.endObject() .endObject()
// re-used: CREATE_TIME // re-used: CREATE_TIME

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@ -299,6 +300,14 @@ public final class ReservedFieldNames {
OutlierDetection.N_NEIGHBORS.getPreferredName(), OutlierDetection.N_NEIGHBORS.getPreferredName(),
OutlierDetection.METHOD.getPreferredName(), OutlierDetection.METHOD.getPreferredName(),
OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
Regression.NAME.getPreferredName(),
Regression.DEPENDENT_VARIABLE.getPreferredName(),
Regression.LAMBDA.getPreferredName(),
Regression.GAMMA.getPreferredName(),
Regression.ETA.getPreferredName(),
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
ElasticsearchMappings.CONFIG_TYPE, ElasticsearchMappings.CONFIG_TYPE,

View File

@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
/**
 * Serialization round-trip and constructor-validation tests for {@link Regression}.
 */
public class RegressionTests extends AbstractSerializingTestCase<Regression> {

    @Override
    protected Regression doParseInstance(XContentParser parser) throws IOException {
        return Regression.fromXContent(parser, false);
    }

    @Override
    protected Regression createTestInstance() {
        return createRandom();
    }

    /** Creates a {@link Regression} where each optional hyperparameter is independently absent half the time. */
    public static Regression createRandom() {
        Double randomLambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
        Double randomGamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
        Double randomEta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
        Integer randomTreeLimit = randomBoolean() ? null : randomIntBetween(1, 2000);
        Double randomBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
        String randomResultField = randomBoolean() ? null : randomAlphaOfLength(10);
        return new Regression(randomAlphaOfLength(10), randomLambda, randomGamma, randomEta, randomTreeLimit,
            randomBagFraction, randomResultField);
    }

    @Override
    protected Writeable.Reader<Regression> instanceReader() {
        return Regression::new;
    }

    /** Builds a {@link Regression} with the given hyperparameters and asserts construction fails with the expected message. */
    private static void assertInvalid(String expectedError, Double lambda, Double gamma, Double eta,
                                      Integer maximumNumberTrees, Double featureBagFraction) {
        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
            () -> new Regression("foo", lambda, gamma, eta, maximumNumberTrees, featureBagFraction, "result"));
        assertThat(e.getMessage(), equalTo(expectedError));
    }

    public void testRegression_GivenNegativeLambda() {
        assertInvalid("[lambda] must be a non-negative double", -0.00001, 0.0, 0.5, 500, 0.3);
    }

    public void testRegression_GivenNegativeGamma() {
        assertInvalid("[gamma] must be a non-negative double", 0.0, -0.00001, 0.5, 500, 0.3);
    }

    public void testRegression_GivenEtaIsZero() {
        assertInvalid("[eta] must be a double in [0.001, 1]", 0.0, 0.0, 0.0, 500, 0.3);
    }

    public void testRegression_GivenEtaIsGreaterThanOne() {
        assertInvalid("[eta] must be a double in [0.001, 1]", 0.0, 0.0, 1.00001, 500, 0.3);
    }

    public void testRegression_GivenMaximumNumberTreesIsZero() {
        assertInvalid("[maximum_number_trees] must be an integer in [1, 2000]", 0.0, 0.0, 0.5, 0, 0.3);
    }

    public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
        assertInvalid("[maximum_number_trees] must be an integer in [1, 2000]", 0.0, 0.0, 0.5, 2001, 0.3);
    }

    public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
        assertInvalid("[feature_bag_fraction] must be a double in (0, 1]", 0.0, 0.0, 0.5, 500, -0.00001);
    }

    public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
        assertInvalid("[feature_bag_fraction] must be a double in (0, 1]", 0.0, 0.0, 0.5, 500, 1.00001);
    }
}

View File

@ -69,6 +69,15 @@ integTest.runner {
'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false',
'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/data_frame_analytics_crud/Test delete given missing config',
'ml/data_frame_analytics_crud/Test max model memory limit', 'ml/data_frame_analytics_crud/Test max model memory limit',
'ml/data_frame_analytics_crud/Test put regression given dependent_variable is not defined',
'ml/data_frame_analytics_crud/Test put regression given negative lambda',
'ml/data_frame_analytics_crud/Test put regression given negative gamma',
'ml/data_frame_analytics_crud/Test put regression given eta less than 1e-3',
'ml/data_frame_analytics_crud/Test put regression given eta greater than one',
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is zero',
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given missing index',
'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given index does not exist',
'ml/evaluate_data_frame/Test given missing evaluation', 'ml/evaluate_data_frame/Test given missing evaluation',

View File

@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -118,4 +119,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
assertThat(stats.get(0).getId(), equalTo(id)); assertThat(stats.get(0).getId(), equalTo(id));
assertThat(stats.get(0).getState(), equalTo(state)); assertThat(stats.get(0).getState(), equalTo(state));
} }
/**
 * Builds a {@link DataFrameAnalyticsConfig} that runs a {@link Regression} analysis
 * (default hyperparameters) on {@code dependentVariable} over the given source
 * indices, writing results to {@code destIndex} under the optional {@code resultsField}.
 */
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
                                                                   @Nullable String resultsField, String dependentVariable) {
    DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id);
    configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
    configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
    configBuilder.setAnalysis(new Regression(dependentVariable));
    return configBuilder.build();
}
} }

View File

@ -21,9 +21,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.junit.After; import org.junit.After;
import java.util.Arrays;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@ -362,4 +365,68 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
} }
/**
 * End-to-end regression test: indexes docs where "variable" is exactly 10x "feature",
 * runs a regression analysis on "variable", and checks that the destination index
 * mirrors every source doc and that at least some docs received a prediction that
 * roughly follows the 10x relationship.
 */
public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception {
    String sourceIndex = "test-regression-with-numeric-feature-and-few-docs";

    BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
    bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

    List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
    List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);

    for (int i = 0; i < 350; i++) {
        Double field = featureValues.get(i % 3);
        Double value = dependentVariableValues.get(i % 3);

        IndexRequest indexRequest = new IndexRequest(sourceIndex);
        if (i < 300) {
            // First 300 docs carry the dependent variable (training data) ...
            indexRequest.source("feature", field, "variable", value);
        } else {
            // ... last 50 omit it, so any prediction for them is pure inference.
            indexRequest.source("feature", field);
        }
        bulkRequestBuilder.add(indexRequest);
    }
    BulkResponse bulkResponse = bulkRequestBuilder.get();
    if (bulkResponse.hasFailures()) {
        fail("Failed to index data: " + bulkResponse.buildFailureMessage());
    }

    String id = "test_regression_with_numeric_feature_and_few_docs";
    DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex},
        sourceIndex + "-results", null, "variable");
    registerAnalytics(config);
    putAnalytics(config);

    assertState(id, DataFrameAnalyticsState.STOPPED);
    startAnalytics(id);
    waitUntilAnalyticsIsStopped(id);

    int resultsWithPrediction = 0;
    SearchResponse sourceData = client().prepareSearch(sourceIndex).get();
    for (SearchHit hit : sourceData.getHits()) {
        // Every source doc must exist in the destination index with its fields copied verbatim.
        GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
        assertThat(destDocGetResponse.isExists(), is(true));
        Map<String, Object> sourceDoc = hit.getSourceAsMap();
        Map<String, Object> destDoc = destDocGetResponse.getSource();
        for (String field : sourceDoc.keySet()) {
            assertThat(destDoc.containsKey(field), is(true));
            assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
        }
        // Results are written under the default "ml" results field.
        assertThat(destDoc.containsKey("ml"), is(true));

        @SuppressWarnings("unchecked")
        Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");

        if (resultsObject.containsKey("variable_prediction")) {
            resultsWithPrediction++;
            double featureValue = (double) destDoc.get("feature");
            double predictionValue = (double) resultsObject.get("variable_prediction");
            // it seems for this case values can be as far off as 2.0
            assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
        }
    }
    // Not every doc is guaranteed a prediction, but at least one must have one.
    assertThat(resultsWithPrediction, greaterThan(0));
}
} }

View File

@ -16,10 +16,12 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException; import java.io.IOException;
import java.text.ParseException; import java.text.ParseException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
/** /**
* Represents a field to be extracted by the datafeed. * Represents a field to be extracted by the datafeed.
@ -37,11 +39,14 @@ public abstract class ExtractedField {
/** The name of the field we extract */ /** The name of the field we extract */
protected final String name; protected final String name;
private final Set<String> types;
private final ExtractionMethod extractionMethod; private final ExtractionMethod extractionMethod;
protected ExtractedField(String alias, String name, ExtractionMethod extractionMethod) { protected ExtractedField(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
this.alias = Objects.requireNonNull(alias); this.alias = Objects.requireNonNull(alias);
this.name = Objects.requireNonNull(name); this.name = Objects.requireNonNull(name);
this.types = Objects.requireNonNull(types);
this.extractionMethod = Objects.requireNonNull(extractionMethod); this.extractionMethod = Objects.requireNonNull(extractionMethod);
} }
@ -53,6 +58,10 @@ public abstract class ExtractedField {
return name; return name;
} }
public Set<String> getTypes() {
return types;
}
public ExtractionMethod getExtractionMethod() { public ExtractionMethod getExtractionMethod() {
return extractionMethod; return extractionMethod;
} }
@ -65,32 +74,32 @@ public abstract class ExtractedField {
return null; return null;
} }
public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) { public static ExtractedField newTimeField(String name, Set<String> types, ExtractionMethod extractionMethod) {
if (extractionMethod == ExtractionMethod.SOURCE) { if (extractionMethod == ExtractionMethod.SOURCE) {
throw new IllegalArgumentException("time field cannot be extracted from source"); throw new IllegalArgumentException("time field cannot be extracted from source");
} }
return new TimeField(name, extractionMethod); return new TimeField(name, types, extractionMethod);
} }
public static ExtractedField newGeoShapeField(String alias, String name) { public static ExtractedField newGeoShapeField(String alias, String name) {
return new GeoShapeField(alias, name); return new GeoShapeField(alias, name, Collections.singleton("geo_shape"));
} }
public static ExtractedField newGeoPointField(String alias, String name) { public static ExtractedField newGeoPointField(String alias, String name) {
return new GeoPointField(alias, name); return new GeoPointField(alias, name, Collections.singleton("geo_point"));
} }
public static ExtractedField newField(String name, ExtractionMethod extractionMethod) { public static ExtractedField newField(String name, Set<String> types, ExtractionMethod extractionMethod) {
return newField(name, name, extractionMethod); return newField(name, name, types, extractionMethod);
} }
public static ExtractedField newField(String alias, String name, ExtractionMethod extractionMethod) { public static ExtractedField newField(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
switch (extractionMethod) { switch (extractionMethod) {
case DOC_VALUE: case DOC_VALUE:
case SCRIPT_FIELD: case SCRIPT_FIELD:
return new FromFields(alias, name, extractionMethod); return new FromFields(alias, name, types, extractionMethod);
case SOURCE: case SOURCE:
return new FromSource(alias, name); return new FromSource(alias, name, types);
default: default:
throw new IllegalArgumentException("Invalid extraction method [" + extractionMethod + "]"); throw new IllegalArgumentException("Invalid extraction method [" + extractionMethod + "]");
} }
@ -98,7 +107,7 @@ public abstract class ExtractedField {
public ExtractedField newFromSource() { public ExtractedField newFromSource() {
if (supportsFromSource()) { if (supportsFromSource()) {
return new FromSource(alias, name); return new FromSource(alias, name, types);
} }
throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via [" throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via ["
+ extractionMethod + "] and cannot be extracted from source"); + extractionMethod + "] and cannot be extracted from source");
@ -106,8 +115,8 @@ public abstract class ExtractedField {
private static class FromFields extends ExtractedField { private static class FromFields extends ExtractedField {
FromFields(String alias, String name, ExtractionMethod extractionMethod) { FromFields(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
super(alias, name, extractionMethod); super(alias, name, types, extractionMethod);
} }
@Override @Override
@ -129,8 +138,8 @@ public abstract class ExtractedField {
private static class GeoShapeField extends FromSource { private static class GeoShapeField extends FromSource {
private static final WellKnownText wkt = new WellKnownText(true, new StandardValidator(true)); private static final WellKnownText wkt = new WellKnownText(true, new StandardValidator(true));
GeoShapeField(String alias, String name) { GeoShapeField(String alias, String name, Set<String> types) {
super(alias, name); super(alias, name, types);
} }
@Override @Override
@ -186,8 +195,8 @@ public abstract class ExtractedField {
private static class GeoPointField extends FromFields { private static class GeoPointField extends FromFields {
GeoPointField(String alias, String name) { GeoPointField(String alias, String name, Set<String> types) {
super(alias, name, ExtractionMethod.DOC_VALUE); super(alias, name, types, ExtractionMethod.DOC_VALUE);
} }
@Override @Override
@ -222,8 +231,8 @@ public abstract class ExtractedField {
private static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
TimeField(String name, ExtractionMethod extractionMethod) { TimeField(String name, Set<String> types, ExtractionMethod extractionMethod) {
super(name, name, extractionMethod); super(name, name, types, extractionMethod);
} }
@Override @Override
@ -255,8 +264,8 @@ public abstract class ExtractedField {
private String[] namePath; private String[] namePath;
FromSource(String alias, String name) { FromSource(String alias, String name, Set<String> types) {
super(alias, name, ExtractionMethod.SOURCE); super(alias, name, types, ExtractionMethod.SOURCE);
namePath = name.split("\\."); namePath = name.split("\\.");
} }

View File

@ -47,15 +47,6 @@ public class ExtractedFields {
return docValueFields; return docValueFields;
} }
/**
* Returns a new instance which only contains fields matching the given extraction method
* @param method the extraction method to filter fields on
* @return a new instance which only contains fields matching the given extraction method
*/
public ExtractedFields filterFields(ExtractedField.ExtractionMethod method) {
return new ExtractedFields(filterFields(method, allFields));
}
private static List<ExtractedField> filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) { private static List<ExtractedField> filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList()); return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList());
} }
@ -79,12 +70,13 @@ public class ExtractedFields {
protected ExtractedField detect(String field) { protected ExtractedField detect(String field) {
String internalField = field; String internalField = field;
ExtractedField.ExtractionMethod method = ExtractedField.ExtractionMethod.SOURCE; ExtractedField.ExtractionMethod method = ExtractedField.ExtractionMethod.SOURCE;
Set<String> types = getTypes(field);
if (scriptFields.contains(field)) { if (scriptFields.contains(field)) {
method = ExtractedField.ExtractionMethod.SCRIPT_FIELD; method = ExtractedField.ExtractionMethod.SCRIPT_FIELD;
} else if (isAggregatable(field)) { } else if (isAggregatable(field)) {
method = ExtractedField.ExtractionMethod.DOC_VALUE; method = ExtractedField.ExtractionMethod.DOC_VALUE;
if (isFieldOfType(field, "date")) { if (isFieldOfType(field, "date")) {
return ExtractedField.newTimeField(field, method); return ExtractedField.newTimeField(field, types, method);
} }
} else if (isFieldOfType(field, TEXT)) { } else if (isFieldOfType(field, TEXT)) {
String parentField = MlStrings.getParentField(field); String parentField = MlStrings.getParentField(field);
@ -107,7 +99,12 @@ public class ExtractedFields {
return ExtractedField.newGeoShapeField(field, internalField); return ExtractedField.newGeoShapeField(field, internalField);
} }
return ExtractedField.newField(field, internalField, method); return ExtractedField.newField(field, internalField, types, method);
}
private Set<String> getTypes(String field) {
Map<String, FieldCapabilities> fieldCaps = fieldsCapabilities.getField(field);
return fieldCaps == null ? Collections.emptySet() : fieldCaps.keySet();
} }
protected boolean isAggregatable(String field) { protected boolean isAggregatable(String field) {

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Job;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
@ -55,12 +56,20 @@ public class TimeBasedExtractedFields extends ExtractedFields {
if (scriptFields.contains(timeField) == false && extractionMethodDetector.isAggregatable(timeField) == false) { if (scriptFields.contains(timeField) == false && extractionMethodDetector.isAggregatable(timeField) == false) {
throw new IllegalArgumentException("cannot retrieve time field [" + timeField + "] because it is not aggregatable"); throw new IllegalArgumentException("cannot retrieve time field [" + timeField + "] because it is not aggregatable");
} }
ExtractedField timeExtractedField = ExtractedField.newTimeField(timeField, scriptFields.contains(timeField) ? ExtractedField timeExtractedField = extractedTimeField(timeField, scriptFields, fieldsCapabilities);
ExtractedField.ExtractionMethod.SCRIPT_FIELD : ExtractedField.ExtractionMethod.DOC_VALUE);
List<String> remainingFields = job.allInputFields().stream().filter(f -> !f.equals(timeField)).collect(Collectors.toList()); List<String> remainingFields = job.allInputFields().stream().filter(f -> !f.equals(timeField)).collect(Collectors.toList());
List<ExtractedField> allExtractedFields = new ArrayList<>(remainingFields.size() + 1); List<ExtractedField> allExtractedFields = new ArrayList<>(remainingFields.size() + 1);
allExtractedFields.add(timeExtractedField); allExtractedFields.add(timeExtractedField);
remainingFields.stream().forEach(field -> allExtractedFields.add(extractionMethodDetector.detect(field))); remainingFields.stream().forEach(field -> allExtractedFields.add(extractionMethodDetector.detect(field)));
return new TimeBasedExtractedFields(timeExtractedField, allExtractedFields); return new TimeBasedExtractedFields(timeExtractedField, allExtractedFields);
} }
private static ExtractedField extractedTimeField(String timeField, Set<String> scriptFields,
FieldCapabilitiesResponse fieldCapabilities) {
if (scriptFields.contains(timeField)) {
return ExtractedField.newTimeField(timeField, Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
}
return ExtractedField.newTimeField(timeField, fieldCapabilities.getField(timeField).keySet(),
ExtractedField.ExtractionMethod.DOC_VALUE);
}
} }

View File

@ -29,11 +29,13 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -179,7 +181,7 @@ public class DataFrameDataExtractor {
for (int i = 0; i < extractedValues.length; ++i) { for (int i = 0; i < extractedValues.length; ++i) {
ExtractedField field = context.extractedFields.getAllFields().get(i); ExtractedField field = context.extractedFields.getAllFields().get(i);
Object[] values = field.value(hit); Object[] values = field.value(hit);
if (values.length == 1 && values[0] instanceof Number) { if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) {
extractedValues[i] = Objects.toString(values[0]); extractedValues[i] = Objects.toString(values[0]);
} else { } else {
extractedValues = null; extractedValues = null;
@ -233,6 +235,17 @@ public class DataFrameDataExtractor {
return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size()); return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size());
} }
public Set<String> getCategoricalFields() {
Set<String> categoricalFields = new HashSet<>();
for (ExtractedField extractedField : context.extractedFields.getAllFields()) {
String fieldName = extractedField.getName();
if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) {
categoricalFields.add(fieldName);
}
}
return categoricalFields;
}
public static class DataSummary { public static class DataSummary {
public final long rows; public final long rows;

View File

@ -5,6 +5,8 @@
*/ */
package org.elasticsearch.xpack.ml.dataframe.extractor; package org.elasticsearch.xpack.ml.dataframe.extractor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
@ -20,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.core.ml.utils.NameResolver;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -35,24 +38,24 @@ import java.util.stream.Stream;
public class ExtractedFieldsDetector { public class ExtractedFieldsDetector {
private static final Logger LOGGER = LogManager.getLogger(ExtractedFieldsDetector.class);
/** /**
* Fields to ignore. These are mostly internal meta fields. * Fields to ignore. These are mostly internal meta fields.
*/ */
private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no",
"_source", "_type", "_uid", "_version", "_feature", "_ignored"); "_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY);
/** public static final Set<String> CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));
* The types supported by data frames
*/ private static final Set<String> NUMERICAL_TYPES;
private static final Set<String> COMPATIBLE_FIELD_TYPES;
static { static {
Set<String> compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values()) Set<String> numericalTypes = Stream.of(NumberFieldMapper.NumberType.values())
.map(NumberFieldMapper.NumberType::typeName) .map(NumberFieldMapper.NumberType::typeName)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module numericalTypes.add("scaled_float");
NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes);
COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes);
} }
private final String[] index; private final String[] index;
@ -79,16 +82,18 @@ public class ExtractedFieldsDetector {
// Ignore fields under the results object // Ignore fields under the results object
fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + ".")); fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + "."));
includeAndExcludeFields(fields);
removeFieldsWithIncompatibleTypes(fields); removeFieldsWithIncompatibleTypes(fields);
includeAndExcludeFields(fields, index); checkRequiredFieldsArePresent(fields);
if (fields.isEmpty()) {
throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index));
}
List<String> sortedFields = new ArrayList<>(fields); List<String> sortedFields = new ArrayList<>(fields);
// We sort the fields to ensure the checksum for each document is deterministic // We sort the fields to ensure the checksum for each document is deterministic
Collections.sort(sortedFields); Collections.sort(sortedFields);
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
.filterFields(ExtractedField.ExtractionMethod.DOC_VALUE);
if (extractedFields.getAllFields().isEmpty()) {
throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index));
}
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
extractedFields = fetchFromSourceIfSupported(extractedFields); extractedFields = fetchFromSourceIfSupported(extractedFields);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
@ -120,13 +125,25 @@ public class ExtractedFieldsDetector {
while (fieldsIterator.hasNext()) { while (fieldsIterator.hasNext()) {
String field = fieldsIterator.next(); String field = fieldsIterator.next();
Map<String, FieldCapabilities> fieldCaps = fieldCapabilitiesResponse.getField(field); Map<String, FieldCapabilities> fieldCaps = fieldCapabilitiesResponse.getField(field);
if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) { if (fieldCaps == null) {
LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field);
fieldsIterator.remove(); fieldsIterator.remove();
} else {
Set<String> fieldTypes = fieldCaps.keySet();
if (NUMERICAL_TYPES.containsAll(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field);
} else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field);
} else {
LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}",
config.getId(), field, fieldTypes);
fieldsIterator.remove();
}
} }
} }
} }
private void includeAndExcludeFields(Set<String> fields, String[] index) { private void includeAndExcludeFields(Set<String> fields) {
FetchSourceContext analyzedFields = config.getAnalyzedFields(); FetchSourceContext analyzedFields = config.getAnalyzedFields();
if (analyzedFields == null) { if (analyzedFields == null) {
return; return;
@ -159,6 +176,16 @@ public class ExtractedFieldsDetector {
} }
} }
private void checkRequiredFieldsArePresent(Set<String> fields) {
List<String> missingFields = config.getAnalysis().getRequiredFields()
.stream()
.filter(f -> fields.contains(f) == false)
.collect(Collectors.toList());
if (missingFields.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields);
}
}
private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) { private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size()); List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getDocValueFields()) { for (ExtractedField field : extractedFields.getDocValueFields()) {

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
public class AnalyticsProcessConfig implements ToXContentObject { public class AnalyticsProcessConfig implements ToXContentObject {
@ -21,21 +22,24 @@ public class AnalyticsProcessConfig implements ToXContentObject {
private static final String THREADS = "threads"; private static final String THREADS = "threads";
private static final String ANALYSIS = "analysis"; private static final String ANALYSIS = "analysis";
private static final String RESULTS_FIELD = "results_field"; private static final String RESULTS_FIELD = "results_field";
private static final String CATEGORICAL_FIELDS = "categorical_fields";
private final long rows; private final long rows;
private final int cols; private final int cols;
private final ByteSizeValue memoryLimit; private final ByteSizeValue memoryLimit;
private final int threads; private final int threads;
private final DataFrameAnalysis analysis;
private final String resultsField; private final String resultsField;
private final Set<String> categoricalFields;
private final DataFrameAnalysis analysis;
public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField, public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField,
DataFrameAnalysis analysis) { Set<String> categoricalFields, DataFrameAnalysis analysis) {
this.rows = rows; this.rows = rows;
this.cols = cols; this.cols = cols;
this.memoryLimit = Objects.requireNonNull(memoryLimit); this.memoryLimit = Objects.requireNonNull(memoryLimit);
this.threads = threads; this.threads = threads;
this.resultsField = Objects.requireNonNull(resultsField); this.resultsField = Objects.requireNonNull(resultsField);
this.categoricalFields = Objects.requireNonNull(categoricalFields);
this.analysis = Objects.requireNonNull(analysis); this.analysis = Objects.requireNonNull(analysis);
} }
@ -51,6 +55,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
builder.field(MEMORY_LIMIT, memoryLimit.getBytes()); builder.field(MEMORY_LIMIT, memoryLimit.getBytes());
builder.field(THREADS, threads); builder.field(THREADS, threads);
builder.field(RESULTS_FIELD, resultsField); builder.field(RESULTS_FIELD, resultsField);
builder.field(CATEGORICAL_FIELDS, categoricalFields);
builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis)); builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis));
builder.endObject(); builder.endObject();
return builder; return builder;

View File

@ -26,6 +26,7 @@ import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@ -283,8 +284,9 @@ public class AnalyticsProcessManager {
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
Set<String> categoricalFields = dataExtractor.getCategoricalFields();
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols,
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis()); config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
return processConfig; return processConfig;
} }
} }

View File

@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.Matchers.startsWith;
@ -19,46 +20,51 @@ public class ExtractedFieldTests extends ESTestCase {
public void testValueGivenDocValue() { public void testValueGivenDocValue() {
SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build(); SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build();
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField single = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(single.value(hit), equalTo(new String[] { "bar" })); assertThat(single.value(hit), equalTo(new String[] { "bar" }));
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField array = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField missing = ExtractedField.newField("missing",Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(missing.value(hit), equalTo(new Object[0])); assertThat(missing.value(hit), equalTo(new Object[0]));
} }
public void testValueGivenScriptField() { public void testValueGivenScriptField() {
SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build(); SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build();
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField single = ExtractedField.newField("single",Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
assertThat(single.value(hit), equalTo(new String[] { "bar" })); assertThat(single.value(hit), equalTo(new String[] { "bar" }));
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField array = ExtractedField.newField("array", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField missing = ExtractedField.newField("missing", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
assertThat(missing.value(hit), equalTo(new Object[0])); assertThat(missing.value(hit), equalTo(new Object[0]));
} }
public void testValueGivenSource() { public void testValueGivenSource() {
SearchHit hit = new SearchHitBuilder(42).setSource("{\"single\":\"bar\",\"array\":[\"a\",\"b\"]}").build(); SearchHit hit = new SearchHitBuilder(42).setSource("{\"single\":\"bar\",\"array\":[\"a\",\"b\"]}").build();
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SOURCE); ExtractedField single = ExtractedField.newField("single", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(single.value(hit), equalTo(new String[] { "bar" })); assertThat(single.value(hit), equalTo(new String[] { "bar" }));
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SOURCE); ExtractedField array = ExtractedField.newField("array", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(array.value(hit), equalTo(new String[] { "a", "b" })); assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE); ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(missing.value(hit), equalTo(new Object[0])); assertThat(missing.value(hit), equalTo(new Object[0]));
} }
public void testValueGivenNestedSource() { public void testValueGivenNestedSource() {
SearchHit hit = new SearchHitBuilder(42).setSource("{\"level_1\":{\"level_2\":{\"foo\":\"bar\"}}}").build(); SearchHit hit = new SearchHitBuilder(42).setSource("{\"level_1\":{\"level_2\":{\"foo\":\"bar\"}}}").build();
ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", ExtractedField.ExtractionMethod.SOURCE); ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
assertThat(nested.value(hit), equalTo(new String[] { "bar" })); assertThat(nested.value(hit), equalTo(new String[] { "bar" }));
} }
@ -91,49 +97,54 @@ public class ExtractedFieldTests extends ESTestCase {
} }
public void testValueGivenSourceAndHitWithNoSource() { public void testValueGivenSourceAndHitWithNoSource() {
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE); ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(missing.value(new SearchHitBuilder(3).build()), equalTo(new Object[0])); assertThat(missing.value(new SearchHitBuilder(3).build()), equalTo(new Object[0]));
} }
public void testValueGivenMismatchingMethod() { public void testValueGivenMismatchingMethod() {
SearchHit hit = new SearchHitBuilder(42).addField("a", 1).setSource("{\"b\":2}").build(); SearchHit hit = new SearchHitBuilder(42).addField("a", 1).setSource("{\"b\":2}").build();
ExtractedField invalidA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE); ExtractedField invalidA = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(invalidA.value(hit), equalTo(new Object[0])); assertThat(invalidA.value(hit), equalTo(new Object[0]));
ExtractedField validA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField validA = ExtractedField.newField("a", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(validA.value(hit), equalTo(new Integer[] { 1 })); assertThat(validA.value(hit), equalTo(new Integer[] { 1 }));
ExtractedField invalidB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField invalidB = ExtractedField.newField("b", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(invalidB.value(hit), equalTo(new Object[0])); assertThat(invalidB.value(hit), equalTo(new Object[0]));
ExtractedField validB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.SOURCE); ExtractedField validB = ExtractedField.newField("b", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(validB.value(hit), equalTo(new Integer[] { 2 })); assertThat(validB.value(hit), equalTo(new Integer[] { 2 }));
} }
public void testValueGivenEmptyHit() { public void testValueGivenEmptyHit() {
SearchHit hit = new SearchHitBuilder(42).build(); SearchHit hit = new SearchHitBuilder(42).build();
ExtractedField docValue = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE); ExtractedField docValue = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(docValue.value(hit), equalTo(new Object[0])); assertThat(docValue.value(hit), equalTo(new Object[0]));
ExtractedField sourceField = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField sourceField = ExtractedField.newField("b", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(sourceField.value(hit), equalTo(new Object[0])); assertThat(sourceField.value(hit), equalTo(new Object[0]));
} }
public void testNewTimeFieldGivenSource() { public void testNewTimeFieldGivenSource() {
expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.SOURCE)); expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.SOURCE));
} }
public void testValueGivenStringTimeField() { public void testValueGivenStringTimeField() {
final long millis = randomLong(); final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", Long.toString(millis)).build(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", Long.toString(millis)).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis })); assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
} }
public void testValueGivenLongTimeField() { public void testValueGivenLongTimeField() {
final long millis = randomLong(); final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis })); assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
} }
@ -141,13 +152,15 @@ public class ExtractedFieldTests extends ESTestCase {
// Prior to 6.x, timestamps were simply `long` milliseconds-past-the-epoch values // Prior to 6.x, timestamps were simply `long` milliseconds-past-the-epoch values
final long millis = randomLong(); final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis })); assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
} }
public void testValueGivenUnknownFormatTimeField() { public void testValueGivenUnknownFormatTimeField() {
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", new Object()).build(); final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", new Object()).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(expectThrows(IllegalStateException.class, () -> timeField.value(hit)).getMessage(), assertThat(expectThrows(IllegalStateException.class, () -> timeField.value(hit)).getMessage(),
startsWith("Unexpected value for a time field")); startsWith("Unexpected value for a time field"));
} }
@ -155,14 +168,15 @@ public class ExtractedFieldTests extends ESTestCase {
public void testAliasVersusName() { public void testAliasVersusName() {
SearchHit hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build(); SearchHit hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build();
ExtractedField field = ExtractedField.newField("a", "a", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField field = ExtractedField.newField("a", "a", Collections.singleton("int"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(field.getAlias(), equalTo("a")); assertThat(field.getAlias(), equalTo("a"));
assertThat(field.getName(), equalTo("a")); assertThat(field.getName(), equalTo("a"));
assertThat(field.value(hit), equalTo(new Integer[] { 1 })); assertThat(field.value(hit), equalTo(new Integer[] { 1 }));
hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build(); hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build();
field = ExtractedField.newField("a", "b", ExtractedField.ExtractionMethod.DOC_VALUE); field = ExtractedField.newField("a", "b", Collections.singleton("int"), ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(field.getAlias(), equalTo("a")); assertThat(field.getAlias(), equalTo("a"));
assertThat(field.getName(), equalTo("b")); assertThat(field.getName(), equalTo("b"));
assertThat(field.value(hit), equalTo(new Integer[] { 2 })); assertThat(field.value(hit), equalTo(new Integer[] { 2 }));
@ -170,11 +184,11 @@ public class ExtractedFieldTests extends ESTestCase {
public void testGetDocValueFormat() { public void testGetDocValueFormat() {
for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) { for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) {
assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(null)); assertThat(ExtractedField.newField("f", Collections.emptySet(), method).getDocValueFormat(), equalTo(null));
} }
assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), assertThat(ExtractedField.newTimeField("doc_value_time", Collections.singleton("date"),
equalTo("epoch_millis")); ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), assertThat(ExtractedField.newTimeField("source_time", Collections.emptySet(),
equalTo("epoch_millis")); ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), equalTo("epoch_millis"));
} }
} }

View File

@ -27,12 +27,18 @@ import static org.mockito.Mockito.when;
public class ExtractedFieldsTests extends ESTestCase { public class ExtractedFieldsTests extends ESTestCase {
public void testAllTypesOfFields() { public void testAllTypesOfFields() {
ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"),
ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("ip"),
ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE); ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(),
ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE); ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));

View File

@ -29,7 +29,8 @@ import static org.mockito.Mockito.when;
public class TimeBasedExtractedFieldsTests extends ESTestCase { public class TimeBasedExtractedFieldsTests extends ESTestCase {
private ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE); private ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
public void testInvalidConstruction() { public void testInvalidConstruction() {
expectThrows(IllegalArgumentException.class, () -> new TimeBasedExtractedFields(timeField, Collections.emptyList())); expectThrows(IllegalArgumentException.class, () -> new TimeBasedExtractedFields(timeField, Collections.emptyList()));
@ -46,12 +47,18 @@ public class TimeBasedExtractedFieldsTests extends ESTestCase {
} }
public void testAllTypesOfFields() { public void testAllTypesOfFields() {
ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"),
ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("float"),
ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE); ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(),
ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE); ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField, TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField,
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));

View File

@ -135,9 +135,11 @@ public class ScrollDataExtractorTests extends ESTestCase {
capturedSearchRequests = new ArrayList<>(); capturedSearchRequests = new ArrayList<>();
capturedContinueScrollIds = new ArrayList<>(); capturedContinueScrollIds = new ArrayList<>();
jobId = "test-job"; jobId = "test-job";
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
extractedFields = new TimeBasedExtractedFields(timeField, extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE))); Arrays.asList(timeField, ExtractedField.newField("field_1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE)));
indices = Arrays.asList("index-1", "index-2"); indices = Arrays.asList("index-1", "index-2");
query = QueryBuilders.matchAllQuery(); query = QueryBuilders.matchAllQuery();
scriptFields = Collections.emptyList(); scriptFields = Collections.emptyList();

View File

@ -16,16 +16,21 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class SearchHitToJsonProcessorTests extends ESTestCase { public class SearchHitToJsonProcessorTests extends ESTestCase {
public void testProcessGivenSingleHit() throws IOException { public void testProcessGivenSingleHit() throws IOException {
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"),
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, missingField, singleField, arrayField)); Arrays.asList(timeField, missingField, singleField, arrayField));
@ -41,10 +46,14 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
} }
public void testProcessGivenMultipleHits() throws IOException { public void testProcessGivenMultipleHits() throws IOException {
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"),
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE); ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, missingField, singleField, arrayField)); Arrays.asList(timeField, missingField, singleField, arrayField));

View File

@ -71,8 +71,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
indices = Arrays.asList("index-1", "index-2"); indices = Arrays.asList("index-1", "index-2");
query = QueryBuilders.matchAllQuery(); query = QueryBuilders.matchAllQuery();
extractedFields = new ExtractedFields(Arrays.asList( extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.DOC_VALUE))); ExtractedField.newField("field_2", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE)));
scrollSize = 1000; scrollSize = 1000;
headers = Collections.emptyMap(); headers = Collections.emptyMap();
@ -288,8 +288,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException { public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException {
extractedFields = new ExtractedFields(Arrays.asList( extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.SOURCE))); ExtractedField.newField("field_2", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE)));
TestExtractor dataExtractor = createExtractor(false); TestExtractor dataExtractor = createExtractor(false);

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
@ -38,11 +39,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
private static final String RESULTS_FIELD = "ml"; private static final String RESULTS_FIELD = "ml";
public void testDetect_GivenFloatField() { public void testDetect_GivenFloatField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float").build(); .addAggregatableField("some_float", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields(); List<ExtractedField> allFields = extractedFields.getAllFields();
@ -52,12 +53,12 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
} }
public void testDetect_GivenNumericFieldWithMultipleTypes() { public void testDetect_GivenNumericFieldWithMultipleTypes() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float")
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields(); List<ExtractedField> allFields = extractedFields.getAllFields();
@ -67,36 +68,36 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
} }
public void testDetect_GivenNonNumericField() { public void testDetect_GivenNonNumericField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_keyword", "keyword").build(); .addAggregatableField("some_keyword", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
} }
public void testDetect_GivenFieldWithNumericAndNonNumericTypes() { public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericTypes() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("indecisive_field", "float", "keyword").build(); .addAggregatableField("indecisive_field", "float", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
} }
public void testDetect_GivenMultipleFields() { public void testDetect_GivenOutlierDetectionAndMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float") .addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long") .addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword") .addAggregatableField("some_keyword", "keyword")
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields(); List<ExtractedField> allFields = extractedFields.getAllFields();
@ -107,12 +108,46 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
} }
public void testDetect_GivenRegressionAndMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.addAggregatableField("foo", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
assertThat(allFields.size(), equalTo(4));
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()),
contains("foo", "some_float", "some_keyword", "some_long"));
assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()),
contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
}
public void testDetect_GivenRegressionAndRequiredFieldMissing() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("required fields [foo] are missing"));
}
public void testDetect_GivenIgnoredField() { public void testDetect_GivenIgnoredField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("_id", "float").build(); .addAggregatableField("_id", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
@ -134,7 +169,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -151,7 +186,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]);
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected")); assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected"));
@ -166,7 +201,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
} }
@ -182,7 +217,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -199,7 +234,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " + assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " +
@ -215,7 +250,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 100, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -232,7 +267,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 4, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -251,7 +286,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 3, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -270,7 +305,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 2, fieldCapabilities); SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect(); ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -280,11 +315,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.SOURCE))); contains(equalTo(ExtractedField.ExtractionMethod.SOURCE)));
} }
private static DataFrameAnalyticsConfig buildAnalyticsConfig() { private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return buildAnalyticsConfig(null); return buildOutlierDetectionConfig(null);
} }
private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analyzedFields) { private static DataFrameAnalyticsConfig buildOutlierDetectionConfig(FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo") return new DataFrameAnalyticsConfig.Builder("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
@ -293,6 +328,19 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
} }
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
return buildRegressionConfig(dependentVariable, null);
}
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
.setAnalyzedFields(analyzedFields)
.setAnalysis(new Regression(dependentVariable))
.build();
}
private static class MockFieldCapsResponseBuilder { private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>(); private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();

View File

@ -607,7 +607,11 @@ setup:
"dest": { "dest": {
"index": "index-bar_dest" "index": "index-bar_dest"
}, },
"analysis": {"outlier_detection":{}} "analysis": {
"regression":{
"dependent_variable": "to_predict"
}
}
} }
- match: { id: "bar" } - match: { id: "bar" }
@ -768,7 +772,11 @@ setup:
"dest": { "dest": {
"index": "index-bar_dest" "index": "index-bar_dest"
}, },
"analysis": {"outlier_detection":{}} "analysis": {
"regression":{
"dependent_variable": "to_predict"
}
}
} }
- match: { id: "bar" } - match: { id: "bar" }
@ -930,3 +938,247 @@ setup:
xpack.ml.max_model_memory_limit: null xpack.ml.max_model_memory_limit: null
- match: {transient: {}} - match: {transient: {}}
---
"Test put regression given dependent_variable is not defined":
- do:
catch: /parse_exception/
ml.put_data_frame_analytics:
id: "regression-without-dependent-variable"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {}
}
}
---
"Test put regression given negative lambda":
- do:
catch: /\[lambda\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "regression-negative-lambda"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"lambda": -1.0
}
}
}
---
"Test put regression given negative gamma":
- do:
catch: /\[gamma\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "regression-negative-gamma"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"gamma": -1.0
}
}
}
---
"Test put regression given eta less than 1e-3":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "regression-eta-greater-less-than-valid"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"eta": 0.0009
}
}
}
---
"Test put regression given eta greater than one":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "regression-eta-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"eta": 1.00001
}
}
}
---
"Test put regression given maximum_number_trees is zero":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "regression-maximum-number-trees-is-zero"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"maximum_number_trees": 0
}
}
}
---
"Test put regression given maximum_number_trees is greater than 2k":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "regression-maximum-number-trees-greater-than-2k"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"maximum_number_trees": 2001
}
}
}
---
"Test put regression given feature_bag_fraction is negative":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "regression-feature-bag-fraction-is-negative"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"feature_bag_fraction": -0.0001
}
}
}
---
"Test put regression given feature_bag_fraction is greater than one":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "regression-feature-bag-fraction-is-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"feature_bag_fraction": 1.0001
}
}
}
---
"Test put regression given valid":
- do:
ml.put_data_frame_analytics:
id: "valid-regression"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
}
}
}
- match: { id: "valid-regression" }
- match: { source.index: ["index-source"] }
- match: { dest.index: "index-dest" }
- match: { analysis: {
"regression":{
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
}
}}
- is_true: create_time
- is_true: version

View File

@ -11,15 +11,22 @@ import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.WarningFailureException; import org.elasticsearch.client.WarningFailureException;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase; import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector; import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.test.rest.XPackRestTestConstants; import org.elasticsearch.xpack.test.rest.XPackRestTestConstants;
import org.elasticsearch.xpack.test.rest.XPackRestTestHelper; import org.elasticsearch.xpack.test.rest.XPackRestTestHelper;
import org.junit.Before; import org.junit.Before;
@ -28,12 +35,12 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClusterRestartTestCase { public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClusterRestartTestCase {
@ -41,14 +48,23 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
private static final String OLD_CLUSTER_JOB_ID = "ml-config-mappings-old-cluster-job"; private static final String OLD_CLUSTER_JOB_ID = "ml-config-mappings-old-cluster-job";
private static final String NEW_CLUSTER_JOB_ID = "ml-config-mappings-new-cluster-job"; private static final String NEW_CLUSTER_JOB_ID = "ml-config-mappings-new-cluster-job";
private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = getDataFrameAnalysisMappings();
mapOf(
"properties", mapOf( @SuppressWarnings("unchecked")
"outlier_detection", mapOf( private static Map<String, Object> getDataFrameAnalysisMappings() {
"properties", mapOf( try (XContentBuilder builder = JsonXContent.contentBuilder()) {
"method", mapOf("type", "keyword"), builder.startObject();
"n_neighbors", mapOf("type", "integer"), ElasticsearchMappings.addDataFrameAnalyticsFields(builder);
"feature_influence_threshold", mapOf("type", "double"))))); builder.endObject();
Map<String, Object> asMap = builder.generator().contentType().xContent().createParser(
NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput()).map();
return (Map<String, Object>) asMap.get(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName());
} catch (IOException e) {
fail("Failed to initialize expected data frame analysis mappings");
}
return null;
}
@Override @Override
protected Settings restClientSettings() { protected Settings restClientSettings() {
@ -71,8 +87,8 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
// trigger .ml-config index creation // trigger .ml-config index creation
createAnomalyDetectorJob(OLD_CLUSTER_JOB_ID); createAnomalyDetectorJob(OLD_CLUSTER_JOB_ID);
if (getOldClusterVersion().onOrAfter(Version.V_7_3_0)) { if (getOldClusterVersion().onOrAfter(Version.V_7_3_0)) {
// .ml-config has correct mappings from the start // .ml-config has mappings for analytics as the feature was introduced in 7.3.0
assertThat(mappingsForDataFrameAnalysis(), is(equalTo(EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS))); assertThat(mappingsForDataFrameAnalysis(), is(notNullValue()));
} else { } else {
// .ml-config does not yet have correct mappings, it will need an update after cluster is upgraded // .ml-config does not yet have correct mappings, it will need an update after cluster is upgraded
assertThat(mappingsForDataFrameAnalysis(), is(nullValue())); assertThat(mappingsForDataFrameAnalysis(), is(nullValue()));
@ -125,18 +141,4 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
mappings = (Map<String, Object>) XContentMapValues.extractValue(mappings, "properties", "analysis"); mappings = (Map<String, Object>) XContentMapValues.extractValue(mappings, "properties", "analysis");
return mappings; return mappings;
} }
private static <K, V> Map<K, V> mapOf(K k1, V v1) {
Map<K, V> map = new HashMap<>();
map.put(k1, v1);
return map;
}
private static <K, V> Map<K, V> mapOf(K k1, V v1, K k2, V v2, K k3, V v3) {
Map<K, V> map = new HashMap<>();
map.put(k1, v1);
map.put(k2, v2);
map.put(k3, v3);
return map;
}
} }