[7.x][ML] Add regression analysis to DF analytics (#45292) (#45388)

This commit adds a first draft of a regression analysis
to data frame analytics. The exact syntax is likely to
change.

This commit adds the new analysis type and its parameters,
along with appropriate validation. It also modifies the extractor
and the fields detector to handle categorical fields, which
regression analysis supports.
Authored by Dimitris Athanasiou on 2019-08-09 19:31:13 +03:00, committed via GitHub.
Commit: 27497ff75f (parent: d1ed9bdbfd)
27 changed files with 1026 additions and 164 deletions
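
For context, a minimal sketch of configuring the new analysis from Java, using the constructor added in this commit (parameter values are illustrative and, as noted above, the exact syntax may still change):

    import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;

    // Only dependent_variable is mandatory; lambda, gamma, eta, maximum_number_trees,
    // feature_bag_fraction and prediction_field_name are optional and validated on construction.
    Regression regression = new Regression("price", 0.5, 0.5, 0.5, 500, 0.3, "price_prediction");

    // Or rely on the defaults:
    Regression withDefaults = new Regression("price");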

View File

@ -147,6 +147,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -454,6 +455,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
MachineLearningFeatureSetUsage::new),
// ML - Data frame analytics
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
// ML - Data frame evaluation
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new),

View File

@ -9,8 +9,22 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import java.util.Map;
import java.util.Set;
public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
/**
* @return The analysis parameters as a map
*/
Map<String, Object> getParams();
/**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
*/
boolean supportsCategoricalFields();
/**
* @return The set of fields that analyzed documents must have for the analysis to operate
*/
Set<String> getRequiredFields();
}
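
A minimal consumer sketch for the two new methods, mirroring what ExtractedFieldsDetector does later in this commit (the helper name here is illustrative, not part of the change):

    static void checkRequiredFields(DataFrameAnalysis analysis, Set<String> detectedFields) {
        // e.g. regression requires its dependent_variable to be among the detected fields
        for (String required : analysis.getRequiredFields()) {
            if (detectedFields.contains(required) == false) {
                throw new IllegalArgumentException("required field [" + required + "] is missing");
            }
        }
    }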

View File

@ -22,6 +22,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
boolean ignoreUnknownFields = (boolean) c;
return OutlierDetection.fromXContent(p, ignoreUnknownFields);
}));
namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
boolean ignoreUnknownFields = (boolean) c;
return Regression.fromXContent(p, ignoreUnknownFields);
}));
return namedXContent;
}
@ -31,6 +35,8 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(),
OutlierDetection::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(),
Regression::new));
return namedWriteables;
}

View File

@ -16,10 +16,12 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
public class OutlierDetection implements DataFrameAnalysis {
@ -152,6 +154,16 @@ public class OutlierDetection implements DataFrameAnalysis {
return params;
}
@Override
public boolean supportsCategoricalFields() {
return false;
}
@Override
public Set<String> getRequiredFields() {
return Collections.emptySet();
}
public enum Method {
LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;

View File

@ -0,0 +1,205 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
public class Regression implements DataFrameAnalysis {
public static final ParseField NAME = new ParseField("regression");
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
public static final ParseField LAMBDA = new ParseField("lambda");
public static final ParseField GAMMA = new ParseField("gamma");
public static final ParseField ETA = new ParseField("eta");
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6]));
parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
return parser;
}
public static Regression fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}
private final String dependentVariable;
private final Double lambda;
private final Double gamma;
private final Double eta;
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
private final String predictionFieldName;
public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
if (lambda != null && lambda < 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
}
this.lambda = lambda;
if (gamma != null && gamma < 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
}
this.gamma = gamma;
if (eta != null && (eta < 0.001 || eta > 1)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
}
this.eta = eta;
if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
}
this.maximumNumberTrees = maximumNumberTrees;
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
}
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
}
public Regression(String dependentVariable) {
this(dependentVariable, null, null, null, null, null, null);
}
public Regression(StreamInput in) throws IOException {
dependentVariable = in.readString();
lambda = in.readOptionalDouble();
gamma = in.readOptionalDouble();
eta = in.readOptionalDouble();
maximumNumberTrees = in.readOptionalVInt();
featureBagFraction = in.readOptionalDouble();
predictionFieldName = in.readOptionalString();
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(dependentVariable);
out.writeOptionalDouble(lambda);
out.writeOptionalDouble(gamma);
out.writeOptionalDouble(eta);
out.writeOptionalVInt(maximumNumberTrees);
out.writeOptionalDouble(featureBagFraction);
out.writeOptionalString(predictionFieldName);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
if (lambda != null) {
builder.field(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
builder.field(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
builder.field(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
builder.endObject();
return builder;
}
@Override
public Map<String, Object> getParams() {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
if (lambda != null) {
params.put(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
params.put(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
params.put(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
return params;
}
@Override
public boolean supportsCategoricalFields() {
return true;
}
@Override
public Set<String> getRequiredFields() {
return Collections.singleton(dependentVariable);
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Regression that = (Regression) o;
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(lambda, that.lambda)
&& Objects.equals(gamma, that.gamma)
&& Objects.equals(eta, that.eta)
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName);
}
}
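
A quick parsing sketch for the class above (the XContent scaffolding here, NamedXContentRegistry.EMPTY and LoggingDeprecationHandler.INSTANCE, is standard plumbing assumed from the surrounding codebase, not part of this commit; run inside a method that declares throws IOException):

    String json = "{\"dependent_variable\": \"price\", \"lambda\": 0.5, \"prediction_field_name\": \"price_prediction\"}";
    try (XContentParser parser = JsonXContent.jsonXContent.createParser(
            NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json)) {
        Regression regression = Regression.fromXContent(parser, false); // strict: unknown fields rejected
        assert "price".equals(regression.getParams().get("dependent_variable"));
    }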

View File

@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@ -443,6 +444,31 @@ public class ElasticsearchMappings {
.endObject()
.endObject()
.endObject()
.startObject(Regression.NAME.getPreferredName())
.startObject(PROPERTIES)
.startObject(Regression.DEPENDENT_VARIABLE.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(Regression.LAMBDA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.GAMMA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.ETA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName())
.field(TYPE, INTEGER)
.endObject()
.startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
// re-used: CREATE_TIME

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@ -299,6 +300,14 @@ public final class ReservedFieldNames {
OutlierDetection.N_NEIGHBORS.getPreferredName(),
OutlierDetection.METHOD.getPreferredName(),
OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
Regression.NAME.getPreferredName(),
Regression.DEPENDENT_VARIABLE.getPreferredName(),
Regression.LAMBDA.getPreferredName(),
Regression.GAMMA.getPreferredName(),
Regression.ETA.getPreferredName(),
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
ElasticsearchMappings.CONFIG_TYPE,

View File

@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
@Override
protected Regression doParseInstance(XContentParser parser) throws IOException {
return Regression.fromXContent(parser, false);
}
@Override
protected Regression createTestInstance() {
return createRandom();
}
public static Regression createRandom() {
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
predictionFieldName);
}
@Override
protected Writeable.Reader<Regression> instanceReader() {
return Regression::new;
}
public void testRegression_GivenNegativeLambda() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result"));
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
}
public void testRegression_GivenNegativeGamma() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result"));
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
}
public void testRegression_GivenEtaIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result"));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testRegression_GivenEtaIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result"));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testRegression_GivenMaximumNumberTreesIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result"));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result"));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result"));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result"));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
}

View File

@ -69,6 +69,15 @@ integTest.runner {
'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false',
'ml/data_frame_analytics_crud/Test delete given missing config',
'ml/data_frame_analytics_crud/Test max model memory limit',
'ml/data_frame_analytics_crud/Test put regression given dependent_variable is not defined',
'ml/data_frame_analytics_crud/Test put regression given negative lambda',
'ml/data_frame_analytics_crud/Test put regression given negative gamma',
'ml/data_frame_analytics_crud/Test put regression given eta less than 1e-3',
'ml/data_frame_analytics_crud/Test put regression given eta greater than one',
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is zero',
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
'ml/evaluate_data_frame/Test given missing index',
'ml/evaluate_data_frame/Test given index does not exist',
'ml/evaluate_data_frame/Test given missing evaluation',

View File

@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import java.io.IOException;
import java.util.ArrayList;
@ -118,4 +119,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
assertThat(stats.get(0).getId(), equalTo(id));
assertThat(stats.get(0).getState(), equalTo(state));
}
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField, String dependentVariable) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id);
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(new Regression(dependentVariable));
return configBuilder.build();
}
}

View File

@ -21,9 +21,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.junit.After;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@ -362,4 +365,68 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
}
public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception {
String sourceIndex = "test-regression-with-numeric-feature-and-few-docs";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
if (i < 300) {
indexRequest.source("feature", field, "variable", value);
} else {
indexRequest.source("feature", field);
}
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
String id = "test_regression_with_numeric_feature_and_few_docs";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex},
sourceIndex + "-results", null, "variable");
registerAnalytics(config);
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
int resultsWithPrediction = 0;
SearchResponse sourceData = client().prepareSearch(sourceIndex).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
if (resultsObject.containsKey("variable_prediction")) {
resultsWithPrediction++;
double featureValue = (double) destDoc.get("feature");
double predictionValue = (double) resultsObject.get("variable_prediction");
// it seems for this case values can be as far off as 2.0
assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
}
}
assertThat(resultsWithPrediction, greaterThan(0));
}
}

View File

@ -16,10 +16,12 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
* Represents a field to be extracted by the datafeed.
@ -37,11 +39,14 @@ public abstract class ExtractedField {
/** The name of the field we extract */
protected final String name;
private final Set<String> types;
private final ExtractionMethod extractionMethod;
protected ExtractedField(String alias, String name, ExtractionMethod extractionMethod) {
protected ExtractedField(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
this.alias = Objects.requireNonNull(alias);
this.name = Objects.requireNonNull(name);
this.types = Objects.requireNonNull(types);
this.extractionMethod = Objects.requireNonNull(extractionMethod);
}
@ -53,6 +58,10 @@ public abstract class ExtractedField {
return name;
}
public Set<String> getTypes() {
return types;
}
public ExtractionMethod getExtractionMethod() {
return extractionMethod;
}
@ -65,32 +74,32 @@ public abstract class ExtractedField {
return null;
}
public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) {
public static ExtractedField newTimeField(String name, Set<String> types, ExtractionMethod extractionMethod) {
if (extractionMethod == ExtractionMethod.SOURCE) {
throw new IllegalArgumentException("time field cannot be extracted from source");
}
return new TimeField(name, extractionMethod);
return new TimeField(name, types, extractionMethod);
}
public static ExtractedField newGeoShapeField(String alias, String name) {
return new GeoShapeField(alias, name);
return new GeoShapeField(alias, name, Collections.singleton("geo_shape"));
}
public static ExtractedField newGeoPointField(String alias, String name) {
return new GeoPointField(alias, name);
return new GeoPointField(alias, name, Collections.singleton("geo_point"));
}
public static ExtractedField newField(String name, ExtractionMethod extractionMethod) {
return newField(name, name, extractionMethod);
public static ExtractedField newField(String name, Set<String> types, ExtractionMethod extractionMethod) {
return newField(name, name, types, extractionMethod);
}
public static ExtractedField newField(String alias, String name, ExtractionMethod extractionMethod) {
public static ExtractedField newField(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
switch (extractionMethod) {
case DOC_VALUE:
case SCRIPT_FIELD:
return new FromFields(alias, name, extractionMethod);
return new FromFields(alias, name, types, extractionMethod);
case SOURCE:
return new FromSource(alias, name);
return new FromSource(alias, name, types);
default:
throw new IllegalArgumentException("Invalid extraction method [" + extractionMethod + "]");
}
@ -98,7 +107,7 @@ public abstract class ExtractedField {
public ExtractedField newFromSource() {
if (supportsFromSource()) {
return new FromSource(alias, name);
return new FromSource(alias, name, types);
}
throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via ["
+ extractionMethod + "] and cannot be extracted from source");
@ -106,8 +115,8 @@ public abstract class ExtractedField {
private static class FromFields extends ExtractedField {
FromFields(String alias, String name, ExtractionMethod extractionMethod) {
super(alias, name, extractionMethod);
FromFields(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
super(alias, name, types, extractionMethod);
}
@Override
@ -129,8 +138,8 @@ public abstract class ExtractedField {
private static class GeoShapeField extends FromSource {
private static final WellKnownText wkt = new WellKnownText(true, new StandardValidator(true));
GeoShapeField(String alias, String name) {
super(alias, name);
GeoShapeField(String alias, String name, Set<String> types) {
super(alias, name, types);
}
@Override
@ -186,8 +195,8 @@ public abstract class ExtractedField {
private static class GeoPointField extends FromFields {
GeoPointField(String alias, String name) {
super(alias, name, ExtractionMethod.DOC_VALUE);
GeoPointField(String alias, String name, Set<String> types) {
super(alias, name, types, ExtractionMethod.DOC_VALUE);
}
@Override
@ -222,8 +231,8 @@ public abstract class ExtractedField {
private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
TimeField(String name, ExtractionMethod extractionMethod) {
super(name, name, extractionMethod);
TimeField(String name, Set<String> types, ExtractionMethod extractionMethod) {
super(name, name, types, extractionMethod);
}
@Override
@ -255,8 +264,8 @@ public abstract class ExtractedField {
private String[] namePath;
FromSource(String alias, String name) {
super(alias, name, ExtractionMethod.SOURCE);
FromSource(String alias, String name, Set<String> types) {
super(alias, name, types, ExtractionMethod.SOURCE);
namePath = name.split("\\.");
}

View File

@ -47,15 +47,6 @@ public class ExtractedFields {
return docValueFields;
}
/**
* Returns a new instance which only contains fields matching the given extraction method
* @param method the extraction method to filter fields on
* @return a new instance which only contains fields matching the given extraction method
*/
public ExtractedFields filterFields(ExtractedField.ExtractionMethod method) {
return new ExtractedFields(filterFields(method, allFields));
}
private static List<ExtractedField> filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList());
}
@ -79,12 +70,13 @@ public class ExtractedFields {
protected ExtractedField detect(String field) {
String internalField = field;
ExtractedField.ExtractionMethod method = ExtractedField.ExtractionMethod.SOURCE;
Set<String> types = getTypes(field);
if (scriptFields.contains(field)) {
method = ExtractedField.ExtractionMethod.SCRIPT_FIELD;
} else if (isAggregatable(field)) {
method = ExtractedField.ExtractionMethod.DOC_VALUE;
if (isFieldOfType(field, "date")) {
return ExtractedField.newTimeField(field, method);
return ExtractedField.newTimeField(field, types, method);
}
} else if (isFieldOfType(field, TEXT)) {
String parentField = MlStrings.getParentField(field);
@ -107,7 +99,12 @@ public class ExtractedFields {
return ExtractedField.newGeoShapeField(field, internalField);
}
return ExtractedField.newField(field, internalField, method);
return ExtractedField.newField(field, internalField, types, method);
}
private Set<String> getTypes(String field) {
Map<String, FieldCapabilities> fieldCaps = fieldsCapabilities.getField(field);
return fieldCaps == null ? Collections.emptySet() : fieldCaps.keySet();
}
protected boolean isAggregatable(String field) {

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Job;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@ -55,12 +56,20 @@ public class TimeBasedExtractedFields extends ExtractedFields {
if (scriptFields.contains(timeField) == false && extractionMethodDetector.isAggregatable(timeField) == false) {
throw new IllegalArgumentException("cannot retrieve time field [" + timeField + "] because it is not aggregatable");
}
ExtractedField timeExtractedField = ExtractedField.newTimeField(timeField, scriptFields.contains(timeField) ?
ExtractedField.ExtractionMethod.SCRIPT_FIELD : ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeExtractedField = extractedTimeField(timeField, scriptFields, fieldsCapabilities);
List<String> remainingFields = job.allInputFields().stream().filter(f -> !f.equals(timeField)).collect(Collectors.toList());
List<ExtractedField> allExtractedFields = new ArrayList<>(remainingFields.size() + 1);
allExtractedFields.add(timeExtractedField);
remainingFields.stream().forEach(field -> allExtractedFields.add(extractionMethodDetector.detect(field)));
return new TimeBasedExtractedFields(timeExtractedField, allExtractedFields);
}
private static ExtractedField extractedTimeField(String timeField, Set<String> scriptFields,
FieldCapabilitiesResponse fieldCapabilities) {
if (scriptFields.contains(timeField)) {
return ExtractedField.newTimeField(timeField, Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
}
return ExtractedField.newTimeField(timeField, fieldCapabilities.getField(timeField).keySet(),
ExtractedField.ExtractionMethod.DOC_VALUE);
}
}

View File

@ -29,11 +29,13 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@ -179,7 +181,7 @@ public class DataFrameDataExtractor {
for (int i = 0; i < extractedValues.length; ++i) {
ExtractedField field = context.extractedFields.getAllFields().get(i);
Object[] values = field.value(hit);
if (values.length == 1 && values[0] instanceof Number) {
if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) {
extractedValues[i] = Objects.toString(values[0]);
} else {
extractedValues = null;
@ -233,6 +235,17 @@ public class DataFrameDataExtractor {
return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size());
}
public Set<String> getCategoricalFields() {
Set<String> categoricalFields = new HashSet<>();
for (ExtractedField extractedField : context.extractedFields.getAllFields()) {
String fieldName = extractedField.getName();
if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) {
categoricalFields.add(fieldName);
}
}
return categoricalFields;
}
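// Illustration (hypothetical mappings): a field mapped as "keyword" in one source index
// and "text" in another yields getTypes() == {"keyword", "text"}; both are in
// CATEGORICAL_TYPES, so the field is reported as categorical. A field mapped as
// "keyword" in one index but "long" in another fails the containsAll check and is skipped.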
public static class DataSummary {
public final long rows;

View File

@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.ml.dataframe.extractor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
@ -20,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NameResolver;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex;
import java.util.ArrayList;
import java.util.Arrays;
@ -35,24 +38,24 @@ import java.util.stream.Stream;
public class ExtractedFieldsDetector {
private static final Logger LOGGER = LogManager.getLogger(ExtractedFieldsDetector.class);
/**
* Fields to ignore. These are mostly internal meta fields.
*/
private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no",
"_source", "_type", "_uid", "_version", "_feature", "_ignored");
"_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY);
/**
* The types supported by data frames
*/
private static final Set<String> COMPATIBLE_FIELD_TYPES;
public static final Set<String> CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));
private static final Set<String> NUMERICAL_TYPES;
static {
Set<String> compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values())
Set<String> numericalTypes = Stream.of(NumberFieldMapper.NumberType.values())
.map(NumberFieldMapper.NumberType::typeName)
.collect(Collectors.toSet());
compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module
COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes);
numericalTypes.add("scaled_float");
NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes);
}
private final String[] index;
@ -79,16 +82,18 @@ public class ExtractedFieldsDetector {
// Ignore fields under the results object
fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + "."));
includeAndExcludeFields(fields);
removeFieldsWithIncompatibleTypes(fields);
includeAndExcludeFields(fields, index);
checkRequiredFieldsArePresent(fields);
if (fields.isEmpty()) {
throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index));
}
List<String> sortedFields = new ArrayList<>(fields);
// We sort the fields to ensure the checksum for each document is deterministic
Collections.sort(sortedFields);
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse)
.filterFields(ExtractedField.ExtractionMethod.DOC_VALUE);
if (extractedFields.getAllFields().isEmpty()) {
throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index));
}
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
extractedFields = fetchFromSourceIfSupported(extractedFields);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
@ -120,13 +125,25 @@ public class ExtractedFieldsDetector {
while (fieldsIterator.hasNext()) {
String field = fieldsIterator.next();
Map<String, FieldCapabilities> fieldCaps = fieldCapabilitiesResponse.getField(field);
if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) {
if (fieldCaps == null) {
LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field);
fieldsIterator.remove();
} else {
Set<String> fieldTypes = fieldCaps.keySet();
if (NUMERICAL_TYPES.containsAll(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field);
} else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) {
LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field);
} else {
LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}",
config.getId(), field, fieldTypes);
fieldsIterator.remove();
}
}
}
}
private void includeAndExcludeFields(Set<String> fields, String[] index) {
private void includeAndExcludeFields(Set<String> fields) {
FetchSourceContext analyzedFields = config.getAnalyzedFields();
if (analyzedFields == null) {
return;
@ -159,6 +176,16 @@ public class ExtractedFieldsDetector {
}
}
private void checkRequiredFieldsArePresent(Set<String> fields) {
List<String> missingFields = config.getAnalysis().getRequiredFields()
.stream()
.filter(f -> fields.contains(f) == false)
.collect(Collectors.toList());
if (missingFields.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields);
}
}
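// Illustration (hypothetical scenario, not a test in this commit): a regression config
// whose dependent_variable is "price", run against an index where "price" is missing or
// was removed as incompatible, now fails fast with: "required fields [price] are missing"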
private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getDocValueFields()) {

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
public class AnalyticsProcessConfig implements ToXContentObject {
@ -21,21 +22,24 @@ public class AnalyticsProcessConfig implements ToXContentObject {
private static final String THREADS = "threads";
private static final String ANALYSIS = "analysis";
private static final String RESULTS_FIELD = "results_field";
private static final String CATEGORICAL_FIELDS = "categorical_fields";
private final long rows;
private final int cols;
private final ByteSizeValue memoryLimit;
private final int threads;
private final DataFrameAnalysis analysis;
private final String resultsField;
private final Set<String> categoricalFields;
private final DataFrameAnalysis analysis;
public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField,
DataFrameAnalysis analysis) {
Set<String> categoricalFields, DataFrameAnalysis analysis) {
this.rows = rows;
this.cols = cols;
this.memoryLimit = Objects.requireNonNull(memoryLimit);
this.threads = threads;
this.resultsField = Objects.requireNonNull(resultsField);
this.categoricalFields = Objects.requireNonNull(categoricalFields);
this.analysis = Objects.requireNonNull(analysis);
}
@ -51,6 +55,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
builder.field(MEMORY_LIMIT, memoryLimit.getBytes());
builder.field(THREADS, threads);
builder.field(RESULTS_FIELD, resultsField);
builder.field(CATEGORICAL_FIELDS, categoricalFields);
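// The new field serializes as a JSON array of the categorical field names,
// e.g. "categorical_fields": ["airline", "carrier"] (illustrative values).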
builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis));
builder.endObject();
return builder;

View File

@ -26,6 +26,7 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
@ -283,8 +284,9 @@ public class AnalyticsProcessManager {
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
Set<String> categoricalFields = dataExtractor.getCategoricalFields();
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols,
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis());
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
return processConfig;
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.startsWith;
@ -19,46 +20,51 @@ public class ExtractedFieldTests extends ESTestCase {
public void testValueGivenDocValue() {
SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build();
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField single = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(single.value(hit), equalTo(new String[] { "bar" }));
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField array = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missing = ExtractedField.newField("missing",Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(missing.value(hit), equalTo(new Object[0]));
}
public void testValueGivenScriptField() {
SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build();
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField single = ExtractedField.newField("single",Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
assertThat(single.value(hit), equalTo(new String[] { "bar" }));
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField array = ExtractedField.newField("array", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField missing = ExtractedField.newField("missing", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
assertThat(missing.value(hit), equalTo(new Object[0]));
}
public void testValueGivenSource() {
SearchHit hit = new SearchHitBuilder(42).setSource("{\"single\":\"bar\",\"array\":[\"a\",\"b\"]}").build();
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField single = ExtractedField.newField("single", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(single.value(hit), equalTo(new String[] { "bar" }));
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField array = ExtractedField.newField("array", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(missing.value(hit), equalTo(new Object[0]));
}
public void testValueGivenNestedSource() {
SearchHit hit = new SearchHitBuilder(42).setSource("{\"level_1\":{\"level_2\":{\"foo\":\"bar\"}}}").build();
ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
assertThat(nested.value(hit), equalTo(new String[] { "bar" }));
}
@ -91,49 +97,54 @@ public class ExtractedFieldTests extends ESTestCase {
}
public void testValueGivenSourceAndHitWithNoSource() {
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(missing.value(new SearchHitBuilder(3).build()), equalTo(new Object[0]));
}
public void testValueGivenMismatchingMethod() {
SearchHit hit = new SearchHitBuilder(42).addField("a", 1).setSource("{\"b\":2}").build();
ExtractedField invalidA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField invalidA = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(invalidA.value(hit), equalTo(new Object[0]));
ExtractedField validA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField validA = ExtractedField.newField("a", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(validA.value(hit), equalTo(new Integer[] { 1 }));
ExtractedField invalidB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField invalidB = ExtractedField.newField("b", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(invalidB.value(hit), equalTo(new Object[0]));
ExtractedField validB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField validB = ExtractedField.newField("b", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(validB.value(hit), equalTo(new Integer[] { 2 }));
}
public void testValueGivenEmptyHit() {
SearchHit hit = new SearchHitBuilder(42).build();
ExtractedField docValue = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField docValue = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(docValue.value(hit), equalTo(new Object[0]));
ExtractedField sourceField = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField sourceField = ExtractedField.newField("b", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(sourceField.value(hit), equalTo(new Object[0]));
}
public void testNewTimeFieldGivenSource() {
expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.SOURCE));
expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.SOURCE));
}
public void testValueGivenStringTimeField() {
final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", Long.toString(millis)).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
}
public void testValueGivenLongTimeField() {
final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
}
@ -141,13 +152,15 @@ public class ExtractedFieldTests extends ESTestCase {
// Prior to 6.x, timestamps were simply `long` milliseconds-past-the-epoch values
final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
}
public void testValueGivenUnknownFormatTimeField() {
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", new Object()).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(expectThrows(IllegalStateException.class, () -> timeField.value(hit)).getMessage(),
startsWith("Unexpected value for a time field"));
}
@ -155,14 +168,15 @@ public class ExtractedFieldTests extends ESTestCase {
public void testAliasVersusName() {
SearchHit hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build();
ExtractedField field = ExtractedField.newField("a", "a", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField field = ExtractedField.newField("a", "a", Collections.singleton("int"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(field.getAlias(), equalTo("a"));
assertThat(field.getName(), equalTo("a"));
assertThat(field.value(hit), equalTo(new Integer[] { 1 }));
hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build();
field = ExtractedField.newField("a", "b", ExtractedField.ExtractionMethod.DOC_VALUE);
field = ExtractedField.newField("a", "b", Collections.singleton("int"), ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(field.getAlias(), equalTo("a"));
assertThat(field.getName(), equalTo("b"));
assertThat(field.value(hit), equalTo(new Integer[] { 2 }));
@ -170,11 +184,11 @@ public class ExtractedFieldTests extends ESTestCase {
public void testGetDocValueFormat() {
for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) {
assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(null));
assertThat(ExtractedField.newField("f", Collections.emptySet(), method).getDocValueFormat(), equalTo(null));
}
assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(),
equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(),
equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("doc_value_time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("source_time", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), equalTo("epoch_millis"));
}
}

View File

@ -27,12 +27,18 @@ import static org.mockito.Mockito.when;
public class ExtractedFieldsTests extends ESTestCase {
public void testAllTypesOfFields() {
ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("ip"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));

View File

@ -29,7 +29,8 @@ import static org.mockito.Mockito.when;
public class TimeBasedExtractedFieldsTests extends ESTestCase {
private ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
private ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
public void testInvalidConstruction() {
expectThrows(IllegalArgumentException.class, () -> new TimeBasedExtractedFields(timeField, Collections.emptyList()));
@ -46,12 +47,18 @@ public class TimeBasedExtractedFieldsTests extends ESTestCase {
}
public void testAllTypesOfFields() {
ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("float"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField,
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));

View File

@ -135,9 +135,11 @@ public class ScrollDataExtractorTests extends ESTestCase {
capturedSearchRequests = new ArrayList<>();
capturedContinueScrollIds = new ArrayList<>();
jobId = "test-job";
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE)));
Arrays.asList(timeField, ExtractedField.newField("field_1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE)));
indices = Arrays.asList("index-1", "index-2");
query = QueryBuilders.matchAllQuery();
scriptFields = Collections.emptyList();

View File

@ -16,16 +16,21 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
public class SearchHitToJsonProcessorTests extends ESTestCase {
public void testProcessGivenSingleHit() throws IOException {
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, missingField, singleField, arrayField));
@ -41,10 +46,14 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
}
public void testProcessGivenMultipleHits() throws IOException {
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, missingField, singleField, arrayField));

View File

@ -71,8 +71,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
indices = Arrays.asList("index-1", "index-2");
query = QueryBuilders.matchAllQuery();
extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.DOC_VALUE)));
ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE)));
scrollSize = 1000;
headers = Collections.emptyMap();
@ -288,8 +288,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException {
extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.SOURCE)));
ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE)));
TestExtractor dataExtractor = createExtractor(false);

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
@ -38,11 +39,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
private static final String RESULTS_FIELD = "ml";
public void testDetect_GivenFloatField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -52,12 +53,12 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
}
public void testDetect_GivenNumericFieldWithMultipleTypes() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -67,36 +68,36 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
}
public void testDetect_GivenNonNumericField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_keyword", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
public void testDetect_GivenFieldWithNumericAndNonNumericTypes() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericTypes() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("indecisive_field", "float", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
public void testDetect_GivenMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
public void testDetect_GivenOutlierDetectionAndMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -107,12 +108,46 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
}
public void testDetect_GivenRegressionAndMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.addAggregatableField("foo", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
assertThat(allFields.size(), equalTo(4));
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()),
contains("foo", "some_float", "some_keyword", "some_long"));
assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()),
contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
}
public void testDetect_GivenRegressionAndRequiredFieldMissing() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("required fields [foo] are missing"));
}
public void testDetect_GivenIgnoredField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("_id", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
@ -134,7 +169,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -151,7 +186,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]);
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected"));
@ -166,7 +201,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
@ -182,7 +217,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -199,7 +234,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " +
@ -215,7 +250,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -232,7 +267,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 4, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -251,7 +286,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 3, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -270,7 +305,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 2, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -280,11 +315,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.SOURCE)));
}
private static DataFrameAnalyticsConfig buildAnalyticsConfig() {
return buildAnalyticsConfig(null);
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return buildOutlierDetectionConfig(null);
}
private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analyzedFields) {
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig(FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
@ -293,6 +328,19 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
}
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
return buildRegressionConfig(dependentVariable, null);
}
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
.setAnalyzedFields(analyzedFields)
.setAnalysis(new Regression(dependentVariable))
.build();
}
private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();

View File

@ -607,7 +607,11 @@ setup:
"dest": {
"index": "index-bar_dest"
},
"analysis": {"outlier_detection":{}}
"analysis": {
"regression":{
"dependent_variable": "to_predict"
}
}
}
- match: { id: "bar" }
@ -768,7 +772,11 @@ setup:
"dest": {
"index": "index-bar_dest"
},
"analysis": {"outlier_detection":{}}
"analysis": {
"regression":{
"dependent_variable": "to_predict"
}
}
}
- match: { id: "bar" }
@ -930,3 +938,247 @@ setup:
xpack.ml.max_model_memory_limit: null
- match: {transient: {}}
---
"Test put regression given dependent_variable is not defined":
- do:
catch: /parse_exception/
ml.put_data_frame_analytics:
id: "regression-without-dependent-variable"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {}
}
}
---
"Test put regression given negative lambda":
- do:
catch: /\[lambda\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "regression-negative-lambda"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"lambda": -1.0
}
}
}
---
"Test put regression given negative gamma":
- do:
catch: /\[gamma\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "regression-negative-gamma"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"gamma": -1.0
}
}
}
---
"Test put regression given eta less than 1e-3":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "regression-eta-greater-less-than-valid"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"eta": 0.0009
}
}
}
---
"Test put regression given eta greater than one":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "regression-eta-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"eta": 1.00001
}
}
}
---
"Test put regression given maximum_number_trees is zero":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "regression-maximum-number-trees-is-zero"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"maximum_number_trees": 0
}
}
}
---
"Test put regression given maximum_number_trees is greater than 2k":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "regression-maximum-number-trees-greater-than-2k"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"maximum_number_trees": 2001
}
}
}
---
"Test put regression given feature_bag_fraction is negative":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "regression-feature-bag-fraction-is-negative"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"feature_bag_fraction": -0.0001
}
}
}
---
"Test put regression given feature_bag_fraction is greater than one":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "regression-feature-bag-fraction-is-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"feature_bag_fraction": 1.0001
}
}
}
---
"Test put regression given valid":
- do:
ml.put_data_frame_analytics:
id: "valid-regression"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
}
}
}
- match: { id: "valid-regression" }
- match: { source.index: ["index-source"] }
- match: { dest.index: "index-dest" }
- match: { analysis: {
"regression":{
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
}
}}
- is_true: create_time
- is_true: version
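
Taken together, the rejection messages above fully specify the hyperparameter ranges: lambda and gamma must be non-negative, eta must fall in [0.001, 1], maximum_number_trees in [1, 2000], and feature_bag_fraction in (0, 1]. The sketch below restates those bounds as plain Java checks, using IllegalArgumentException instead of the server's exception helpers; treat it as documentation of the ranges, not the actual parser code.

    // Hedged sketch: validates the optional regression hyperparameters against the
    // ranges asserted by the REST tests above. A null argument means "use the default".
    class RegressionParamsValidatorSketch {
        static void validate(Double lambda, Double gamma, Double eta,
                             Integer maximumNumberTrees, Double featureBagFraction) {
            if (lambda != null && lambda < 0) {
                throw new IllegalArgumentException("[lambda] must be a non-negative double");
            }
            if (gamma != null && gamma < 0) {
                throw new IllegalArgumentException("[gamma] must be a non-negative double");
            }
            if (eta != null && (eta < 0.001 || eta > 1)) {
                throw new IllegalArgumentException("[eta] must be a double in [0.001, 1]");
            }
            if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
                throw new IllegalArgumentException("[maximum_number_trees] must be an integer in [1, 2000]");
            }
            if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1)) {
                throw new IllegalArgumentException("[feature_bag_fraction] must be a double in (0, 1]");
            }
        }
    }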

View File

@ -11,15 +11,22 @@ import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.WarningFailureException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.test.rest.XPackRestTestConstants;
import org.elasticsearch.xpack.test.rest.XPackRestTestHelper;
import org.junit.Before;
@ -28,12 +35,12 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClusterRestartTestCase {
@ -41,14 +48,23 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
private static final String OLD_CLUSTER_JOB_ID = "ml-config-mappings-old-cluster-job";
private static final String NEW_CLUSTER_JOB_ID = "ml-config-mappings-new-cluster-job";
private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS =
mapOf(
"properties", mapOf(
"outlier_detection", mapOf(
"properties", mapOf(
"method", mapOf("type", "keyword"),
"n_neighbors", mapOf("type", "integer"),
"feature_influence_threshold", mapOf("type", "double")))));
private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = getDataFrameAnalysisMappings();
@SuppressWarnings("unchecked")
private static Map<String, Object> getDataFrameAnalysisMappings() {
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
builder.startObject();
ElasticsearchMappings.addDataFrameAnalyticsFields(builder);
builder.endObject();
Map<String, Object> asMap = builder.generator().contentType().xContent().createParser(
NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput()).map();
return (Map<String, Object>) asMap.get(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName());
} catch (IOException e) {
fail("Failed to initialize expected data frame analysis mappings");
}
return null;
}
@Override
protected Settings restClientSettings() {
@ -71,8 +87,8 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
// trigger .ml-config index creation
createAnomalyDetectorJob(OLD_CLUSTER_JOB_ID);
if (getOldClusterVersion().onOrAfter(Version.V_7_3_0)) {
// .ml-config has correct mappings from the start
assertThat(mappingsForDataFrameAnalysis(), is(equalTo(EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS)));
// .ml-config has mappings for analytics as the feature was introduced in 7.3.0
assertThat(mappingsForDataFrameAnalysis(), is(notNullValue()));
} else {
// .ml-config does not yet have correct mappings, it will need an update after cluster is upgraded
assertThat(mappingsForDataFrameAnalysis(), is(nullValue()));
@ -125,18 +141,4 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
mappings = (Map<String, Object>) XContentMapValues.extractValue(mappings, "properties", "analysis");
return mappings;
}
private static <K, V> Map<K, V> mapOf(K k1, V v1) {
Map<K, V> map = new HashMap<>();
map.put(k1, v1);
return map;
}
private static <K, V> Map<K, V> mapOf(K k1, V v1, K k2, V v2, K k3, V v3) {
Map<K, V> map = new HashMap<>();
map.put(k1, v1);
map.put(k2, v2);
map.put(k3, v3);
return map;
}
}