This commit adds a first draft of regression analysis to data frame analytics; the exact syntax is likely to change. It introduces the new analysis type with its parameters and the appropriate validation, and it modifies the data extractor and the fields detector so they can handle categorical fields, which regression analysis supports.
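For orientation, here is a minimal sketch (not part of the diff itself) of how the new analysis type is wired together, using only constructors and setters that appear in the changes below; the job id and the index/field names are illustrative:

    // "houses-regression", "house-prices" and "price" are made-up example names
    DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder("houses-regression");
    builder.setSource(new DataFrameAnalyticsSource(new String[] {"house-prices"}, null));
    builder.setDest(new DataFrameAnalyticsDest("house-prices-results", null));
    builder.setAnalysis(new Regression("price")); // "price" is the dependent variable
    DataFrameAnalyticsConfig config = builder.build();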
Parent: d1ed9bdbfd
Commit: 27497ff75f
XPackClientPlugin.java
@@ -147,6 +147,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@@ -454,6 +455,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
                 MachineLearningFeatureSetUsage::new),
             // ML - Data frame analytics
             new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
+            new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
             // ML - Data frame evaluation
             new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
                 BinarySoftClassification::new),
DataFrameAnalysis.java
@@ -9,8 +9,22 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;

 import java.util.Map;
+import java.util.Set;

 public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {

+    /**
+     * @return The analysis parameters as a map
+     */
     Map<String, Object> getParams();
+
+    /**
+     * @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
+     */
+    boolean supportsCategoricalFields();
+
+    /**
+     * @return The set of fields that analyzed documents must have for the analysis to operate
+     */
+    Set<String> getRequiredFields();
 }
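A hedged sketch of the contract these two new methods introduce, matching the implementations added further down in this commit (the field name "price" is illustrative):

    DataFrameAnalysis analysis = new Regression("price");
    boolean categorical = analysis.supportsCategoricalFields(); // true for regression, false for outlier detection
    Set<String> required = analysis.getRequiredFields();        // Collections.singleton("price")

The fields detector uses supportsCategoricalFields() to decide whether text/keyword/ip fields are kept, and getRequiredFields() to fail fast when the dependent variable is missing from the source index.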
MlDataFrameAnalysisNamedXContentProvider.java
@@ -22,6 +22,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
             boolean ignoreUnknownFields = (boolean) c;
             return OutlierDetection.fromXContent(p, ignoreUnknownFields);
         }));
+        namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
+            boolean ignoreUnknownFields = (boolean) c;
+            return Regression.fromXContent(p, ignoreUnknownFields);
+        }));

         return namedXContent;
     }
@@ -31,6 +35,8 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr

         namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(),
             OutlierDetection::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(),
+            Regression::new));

         return namedWriteables;
     }
OutlierDetection.java
@@ -16,10 +16,12 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;

 public class OutlierDetection implements DataFrameAnalysis {

@@ -152,6 +154,16 @@ public class OutlierDetection implements DataFrameAnalysis {
         return params;
     }

+    @Override
+    public boolean supportsCategoricalFields() {
+        return false;
+    }
+
+    @Override
+    public Set<String> getRequiredFields() {
+        return Collections.emptySet();
+    }
+
     public enum Method {
         LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;
Regression.java (new file)
@@ -0,0 +1,205 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.analyses;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+public class Regression implements DataFrameAnalysis {
+
+    public static final ParseField NAME = new ParseField("regression");
+
+    public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
+    public static final ParseField LAMBDA = new ParseField("lambda");
+    public static final ParseField GAMMA = new ParseField("gamma");
+    public static final ParseField ETA = new ParseField("eta");
+    public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
+    public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
+    public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
+
+    private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
+
+    private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
+        ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
+            a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6]));
+        parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
+        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
+        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
+        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
+        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
+        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
+        return parser;
+    }
+
+    public static Regression fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
+        return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
+    }
+
+    private final String dependentVariable;
+    private final Double lambda;
+    private final Double gamma;
+    private final Double eta;
+    private final Integer maximumNumberTrees;
+    private final Double featureBagFraction;
+    private final String predictionFieldName;
+
+    public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
+                      @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) {
+        this.dependentVariable = Objects.requireNonNull(dependentVariable);
+
+        if (lambda != null && lambda < 0) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
+        }
+        this.lambda = lambda;
+
+        if (gamma != null && gamma < 0) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
+        }
+        this.gamma = gamma;
+
+        if (eta != null && (eta < 0.001 || eta > 1)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
+        }
+        this.eta = eta;
+
+        if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
+        }
+        this.maximumNumberTrees = maximumNumberTrees;
+
+        if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
+        }
+        this.featureBagFraction = featureBagFraction;
+
+        this.predictionFieldName = predictionFieldName;
+    }
+
+    public Regression(String dependentVariable) {
+        this(dependentVariable, null, null, null, null, null, null);
+    }
+
+    public Regression(StreamInput in) throws IOException {
+        dependentVariable = in.readString();
+        lambda = in.readOptionalDouble();
+        gamma = in.readOptionalDouble();
+        eta = in.readOptionalDouble();
+        maximumNumberTrees = in.readOptionalVInt();
+        featureBagFraction = in.readOptionalDouble();
+        predictionFieldName = in.readOptionalString();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(dependentVariable);
+        out.writeOptionalDouble(lambda);
+        out.writeOptionalDouble(gamma);
+        out.writeOptionalDouble(eta);
+        out.writeOptionalVInt(maximumNumberTrees);
+        out.writeOptionalDouble(featureBagFraction);
+        out.writeOptionalString(predictionFieldName);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
+        if (lambda != null) {
+            builder.field(LAMBDA.getPreferredName(), lambda);
+        }
+        if (gamma != null) {
+            builder.field(GAMMA.getPreferredName(), gamma);
+        }
+        if (eta != null) {
+            builder.field(ETA.getPreferredName(), eta);
+        }
+        if (maximumNumberTrees != null) {
+            builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
+        }
+        if (featureBagFraction != null) {
+            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
+        }
+        if (predictionFieldName != null) {
+            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public Map<String, Object> getParams() {
+        Map<String, Object> params = new HashMap<>();
+        params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
+        if (lambda != null) {
+            params.put(LAMBDA.getPreferredName(), lambda);
+        }
+        if (gamma != null) {
+            params.put(GAMMA.getPreferredName(), gamma);
+        }
+        if (eta != null) {
+            params.put(ETA.getPreferredName(), eta);
+        }
+        if (maximumNumberTrees != null) {
+            params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
+        }
+        if (featureBagFraction != null) {
+            params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
+        }
+        if (predictionFieldName != null) {
+            params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
+        }
+        return params;
+    }
+
+    @Override
+    public boolean supportsCategoricalFields() {
+        return true;
+    }
+
+    @Override
+    public Set<String> getRequiredFields() {
+        return Collections.singleton(dependentVariable);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Regression that = (Regression) o;
+        return Objects.equals(dependentVariable, that.dependentVariable)
+            && Objects.equals(lambda, that.lambda)
+            && Objects.equals(gamma, that.gamma)
+            && Objects.equals(eta, that.eta)
+            && Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
+            && Objects.equals(featureBagFraction, that.featureBagFraction)
+            && Objects.equals(predictionFieldName, that.predictionFieldName);
+    }
+}
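A short usage sketch of the validation above (values are illustrative; only the constructor and getParams() defined in this file are assumed):

    Regression regression = new Regression("price", 0.5, 0.5, 0.3, 100, 0.7, "price_prediction");
    Map<String, Object> params = regression.getParams(); // {dependent_variable=price, lambda=0.5, ...}

    // Out-of-range values are rejected with a bad-request exception, e.g.
    // new Regression("price", -1.0, null, null, null, null, null)
    // throws: [lambda] must be a non-negative double

In summary: lambda and gamma must be non-negative, eta must lie in [0.001, 1], maximum_number_trees in [1, 2000], and feature_bag_fraction in (0, 1].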
ElasticsearchMappings.java
@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -443,6 +444,31 @@ public class ElasticsearchMappings {
                         .endObject()
                     .endObject()
                 .endObject()
+                .startObject(Regression.NAME.getPreferredName())
+                    .startObject(PROPERTIES)
+                        .startObject(Regression.DEPENDENT_VARIABLE.getPreferredName())
+                            .field(TYPE, KEYWORD)
+                        .endObject()
+                        .startObject(Regression.LAMBDA.getPreferredName())
+                            .field(TYPE, DOUBLE)
+                        .endObject()
+                        .startObject(Regression.GAMMA.getPreferredName())
+                            .field(TYPE, DOUBLE)
+                        .endObject()
+                        .startObject(Regression.ETA.getPreferredName())
+                            .field(TYPE, DOUBLE)
+                        .endObject()
+                        .startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName())
+                            .field(TYPE, INTEGER)
+                        .endObject()
+                        .startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName())
+                            .field(TYPE, DOUBLE)
+                        .endObject()
+                        .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
+                            .field(TYPE, KEYWORD)
+                        .endObject()
+                    .endObject()
+                .endObject()
             .endObject()
         .endObject()
         // re-used: CREATE_TIME
ReservedFieldNames.java
@@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -299,6 +300,14 @@ public final class ReservedFieldNames {
         OutlierDetection.N_NEIGHBORS.getPreferredName(),
         OutlierDetection.METHOD.getPreferredName(),
         OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
+        Regression.NAME.getPreferredName(),
+        Regression.DEPENDENT_VARIABLE.getPreferredName(),
+        Regression.LAMBDA.getPreferredName(),
+        Regression.GAMMA.getPreferredName(),
+        Regression.ETA.getPreferredName(),
+        Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
+        Regression.FEATURE_BAG_FRACTION.getPreferredName(),
+        Regression.PREDICTION_FIELD_NAME.getPreferredName(),

         ElasticsearchMappings.CONFIG_TYPE,
RegressionTests.java (new file)
@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.analyses;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class RegressionTests extends AbstractSerializingTestCase<Regression> {
+
+    @Override
+    protected Regression doParseInstance(XContentParser parser) throws IOException {
+        return Regression.fromXContent(parser, false);
+    }
+
+    @Override
+    protected Regression createTestInstance() {
+        return createRandom();
+    }
+
+    public static Regression createRandom() {
+        Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
+        Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
+        Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
+        Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
+        Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
+        String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
+        return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
+            predictionFieldName);
+    }
+
+    @Override
+    protected Writeable.Reader<Regression> instanceReader() {
+        return Regression::new;
+    }
+
+    public void testRegression_GivenNegativeLambda() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result"));
+
+        assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
+    }
+
+    public void testRegression_GivenNegativeGamma() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result"));
+
+        assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
+    }
+
+    public void testRegression_GivenEtaIsZero() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result"));
+
+        assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
+    }
+
+    public void testRegression_GivenEtaIsGreaterThanOne() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result"));
+
+        assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
+    }
+
+    public void testRegression_GivenMaximumNumberTreesIsZero() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result"));
+
+        assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
+    }
+
+    public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result"));
+
+        assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
+    }
+
+    public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result"));
+
+        assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
+    }
+
+    public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result"));
+
+        assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
+    }
+}
build.gradle
@@ -69,6 +69,15 @@ integTest.runner {
         'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false',
         'ml/data_frame_analytics_crud/Test delete given missing config',
         'ml/data_frame_analytics_crud/Test max model memory limit',
+        'ml/data_frame_analytics_crud/Test put regression given dependent_variable is not defined',
+        'ml/data_frame_analytics_crud/Test put regression given negative lambda',
+        'ml/data_frame_analytics_crud/Test put regression given negative gamma',
+        'ml/data_frame_analytics_crud/Test put regression given eta less than 1e-3',
+        'ml/data_frame_analytics_crud/Test put regression given eta greater than one',
+        'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is zero',
+        'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k',
+        'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
+        'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
         'ml/evaluate_data_frame/Test given missing index',
         'ml/evaluate_data_frame/Test given index does not exist',
         'ml/evaluate_data_frame/Test given missing evaluation',
MlNativeDataFrameAnalyticsIntegTestCase.java
@@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;

 import java.io.IOException;
 import java.util.ArrayList;
@@ -118,4 +119,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
         assertThat(stats.get(0).getId(), equalTo(id));
         assertThat(stats.get(0).getState(), equalTo(state));
     }
+
+    protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
+                                                                       @Nullable String resultsField, String dependentVariable) {
+        DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id);
+        configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
+        configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
+        configBuilder.setAnalysis(new Regression(dependentVariable));
+        return configBuilder.build();
+    }
 }
RunDataFrameAnalyticsIT.java
@@ -21,9 +21,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.junit.After;

 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;

 import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -362,4 +365,68 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
             .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
         assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
     }
+
+    public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception {
+        String sourceIndex = "test-regression-with-numeric-feature-and-few-docs";
+
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
+        bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+
+        List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
+        List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
+
+        for (int i = 0; i < 350; i++) {
+            Double field = featureValues.get(i % 3);
+            Double value = dependentVariableValues.get(i % 3);
+
+            IndexRequest indexRequest = new IndexRequest(sourceIndex);
+            if (i < 300) {
+                indexRequest.source("feature", field, "variable", value);
+            } else {
+                indexRequest.source("feature", field);
+            }
+            bulkRequestBuilder.add(indexRequest);
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+
+        String id = "test_regression_with_numeric_feature_and_few_docs";
+        DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex},
+            sourceIndex + "-results", null, "variable");
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        assertState(id, DataFrameAnalyticsState.STOPPED);
+
+        startAnalytics(id);
+        waitUntilAnalyticsIsStopped(id);
+
+        int resultsWithPrediction = 0;
+        SearchResponse sourceData = client().prepareSearch(sourceIndex).get();
+        for (SearchHit hit : sourceData.getHits()) {
+            GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
+            assertThat(destDocGetResponse.isExists(), is(true));
+            Map<String, Object> sourceDoc = hit.getSourceAsMap();
+            Map<String, Object> destDoc = destDocGetResponse.getSource();
+            for (String field : sourceDoc.keySet()) {
+                assertThat(destDoc.containsKey(field), is(true));
+                assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
+            }
+            assertThat(destDoc.containsKey("ml"), is(true));
+
+            @SuppressWarnings("unchecked")
+            Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
+
+            if (resultsObject.containsKey("variable_prediction")) {
+                resultsWithPrediction++;
+                double featureValue = (double) destDoc.get("feature");
+                double predictionValue = (double) resultsObject.get("variable_prediction");
+                // it seems for this case values can be as far off as 2.0
+                assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
+            }
+        }
+        assertThat(resultsWithPrediction, greaterThan(0));
+    }
 }
ExtractedField.java
@@ -16,10 +16,12 @@ import org.elasticsearch.search.SearchHit;
 import java.io.IOException;
 import java.text.ParseException;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;

 /**
  * Represents a field to be extracted by the datafeed.
@@ -37,11 +39,14 @@ public abstract class ExtractedField {
     /** The name of the field we extract */
     protected final String name;

+    private final Set<String> types;
+
     private final ExtractionMethod extractionMethod;

-    protected ExtractedField(String alias, String name, ExtractionMethod extractionMethod) {
+    protected ExtractedField(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
         this.alias = Objects.requireNonNull(alias);
         this.name = Objects.requireNonNull(name);
+        this.types = Objects.requireNonNull(types);
         this.extractionMethod = Objects.requireNonNull(extractionMethod);
     }

@@ -53,6 +58,10 @@ public abstract class ExtractedField {
         return name;
     }

+    public Set<String> getTypes() {
+        return types;
+    }
+
     public ExtractionMethod getExtractionMethod() {
         return extractionMethod;
     }
@@ -65,32 +74,32 @@ public abstract class ExtractedField {
         return null;
     }

-    public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) {
+    public static ExtractedField newTimeField(String name, Set<String> types, ExtractionMethod extractionMethod) {
         if (extractionMethod == ExtractionMethod.SOURCE) {
             throw new IllegalArgumentException("time field cannot be extracted from source");
         }
-        return new TimeField(name, extractionMethod);
+        return new TimeField(name, types, extractionMethod);
     }

     public static ExtractedField newGeoShapeField(String alias, String name) {
-        return new GeoShapeField(alias, name);
+        return new GeoShapeField(alias, name, Collections.singleton("geo_shape"));
     }

     public static ExtractedField newGeoPointField(String alias, String name) {
-        return new GeoPointField(alias, name);
+        return new GeoPointField(alias, name, Collections.singleton("geo_point"));
     }

-    public static ExtractedField newField(String name, ExtractionMethod extractionMethod) {
-        return newField(name, name, extractionMethod);
+    public static ExtractedField newField(String name, Set<String> types, ExtractionMethod extractionMethod) {
+        return newField(name, name, types, extractionMethod);
     }

-    public static ExtractedField newField(String alias, String name, ExtractionMethod extractionMethod) {
+    public static ExtractedField newField(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
         switch (extractionMethod) {
             case DOC_VALUE:
             case SCRIPT_FIELD:
-                return new FromFields(alias, name, extractionMethod);
+                return new FromFields(alias, name, types, extractionMethod);
             case SOURCE:
-                return new FromSource(alias, name);
+                return new FromSource(alias, name, types);
             default:
                 throw new IllegalArgumentException("Invalid extraction method [" + extractionMethod + "]");
         }
@@ -98,7 +107,7 @@ public abstract class ExtractedField {

     public ExtractedField newFromSource() {
         if (supportsFromSource()) {
-            return new FromSource(alias, name);
+            return new FromSource(alias, name, types);
         }
         throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via ["
             + extractionMethod + "] and cannot be extracted from source");
@@ -106,8 +115,8 @@ public abstract class ExtractedField {

     private static class FromFields extends ExtractedField {

-        FromFields(String alias, String name, ExtractionMethod extractionMethod) {
-            super(alias, name, extractionMethod);
+        FromFields(String alias, String name, Set<String> types, ExtractionMethod extractionMethod) {
+            super(alias, name, types, extractionMethod);
         }

         @Override
@@ -129,8 +138,8 @@ public abstract class ExtractedField {
     private static class GeoShapeField extends FromSource {
         private static final WellKnownText wkt = new WellKnownText(true, new StandardValidator(true));

-        GeoShapeField(String alias, String name) {
-            super(alias, name);
+        GeoShapeField(String alias, String name, Set<String> types) {
+            super(alias, name, types);
         }

         @Override
@@ -186,8 +195,8 @@ public abstract class ExtractedField {

     private static class GeoPointField extends FromFields {

-        GeoPointField(String alias, String name) {
-            super(alias, name, ExtractionMethod.DOC_VALUE);
+        GeoPointField(String alias, String name, Set<String> types) {
+            super(alias, name, types, ExtractionMethod.DOC_VALUE);
         }

         @Override
@@ -222,8 +231,8 @@ public abstract class ExtractedField {

         private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";

-        TimeField(String name, ExtractionMethod extractionMethod) {
-            super(name, name, extractionMethod);
+        TimeField(String name, Set<String> types, ExtractionMethod extractionMethod) {
+            super(name, name, types, extractionMethod);
         }

         @Override
@@ -255,8 +264,8 @@ public abstract class ExtractedField {

         private String[] namePath;

-        FromSource(String alias, String name) {
-            super(alias, name, ExtractionMethod.SOURCE);
+        FromSource(String alias, String name, Set<String> types) {
+            super(alias, name, types, ExtractionMethod.SOURCE);
             namePath = name.split("\\.");
         }
ExtractedFields.java
@@ -47,15 +47,6 @@ public class ExtractedFields {
         return docValueFields;
     }

-    /**
-     * Returns a new instance which only contains fields matching the given extraction method
-     * @param method the extraction method to filter fields on
-     * @return a new instance which only contains fields matching the given extraction method
-     */
-    public ExtractedFields filterFields(ExtractedField.ExtractionMethod method) {
-        return new ExtractedFields(filterFields(method, allFields));
-    }
-
     private static List<ExtractedField> filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
         return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList());
     }
@@ -79,12 +70,13 @@ public class ExtractedFields {
     protected ExtractedField detect(String field) {
         String internalField = field;
         ExtractedField.ExtractionMethod method = ExtractedField.ExtractionMethod.SOURCE;
+        Set<String> types = getTypes(field);
         if (scriptFields.contains(field)) {
             method = ExtractedField.ExtractionMethod.SCRIPT_FIELD;
         } else if (isAggregatable(field)) {
             method = ExtractedField.ExtractionMethod.DOC_VALUE;
             if (isFieldOfType(field, "date")) {
-                return ExtractedField.newTimeField(field, method);
+                return ExtractedField.newTimeField(field, types, method);
             }
         } else if (isFieldOfType(field, TEXT)) {
             String parentField = MlStrings.getParentField(field);
@@ -107,7 +99,12 @@ public class ExtractedFields {
             return ExtractedField.newGeoShapeField(field, internalField);
         }

-        return ExtractedField.newField(field, internalField, method);
+        return ExtractedField.newField(field, internalField, types, method);
     }

+    private Set<String> getTypes(String field) {
+        Map<String, FieldCapabilities> fieldCaps = fieldsCapabilities.getField(field);
+        return fieldCaps == null ? Collections.emptySet() : fieldCaps.keySet();
+    }
+
     protected boolean isAggregatable(String field) {
TimeBasedExtractedFields.java
@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Job;

 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
@@ -55,12 +56,20 @@ public class TimeBasedExtractedFields extends ExtractedFields {
         if (scriptFields.contains(timeField) == false && extractionMethodDetector.isAggregatable(timeField) == false) {
             throw new IllegalArgumentException("cannot retrieve time field [" + timeField + "] because it is not aggregatable");
         }
-        ExtractedField timeExtractedField = ExtractedField.newTimeField(timeField, scriptFields.contains(timeField) ?
-            ExtractedField.ExtractionMethod.SCRIPT_FIELD : ExtractedField.ExtractionMethod.DOC_VALUE);
+        ExtractedField timeExtractedField = extractedTimeField(timeField, scriptFields, fieldsCapabilities);
         List<String> remainingFields = job.allInputFields().stream().filter(f -> !f.equals(timeField)).collect(Collectors.toList());
         List<ExtractedField> allExtractedFields = new ArrayList<>(remainingFields.size() + 1);
         allExtractedFields.add(timeExtractedField);
         remainingFields.stream().forEach(field -> allExtractedFields.add(extractionMethodDetector.detect(field)));
         return new TimeBasedExtractedFields(timeExtractedField, allExtractedFields);
     }
+
+    private static ExtractedField extractedTimeField(String timeField, Set<String> scriptFields,
+                                                     FieldCapabilitiesResponse fieldCapabilities) {
+        if (scriptFields.contains(timeField)) {
+            return ExtractedField.newTimeField(timeField, Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
+        }
+        return ExtractedField.newTimeField(timeField, fieldCapabilities.getField(timeField).keySet(),
+            ExtractedField.ExtractionMethod.DOC_VALUE);
+    }
 }
DataFrameDataExtractor.java
@@ -29,11 +29,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
@@ -179,7 +181,7 @@ public class DataFrameDataExtractor {
         for (int i = 0; i < extractedValues.length; ++i) {
             ExtractedField field = context.extractedFields.getAllFields().get(i);
             Object[] values = field.value(hit);
-            if (values.length == 1 && values[0] instanceof Number) {
+            if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) {
                 extractedValues[i] = Objects.toString(values[0]);
             } else {
                 extractedValues = null;
@@ -233,6 +235,17 @@ public class DataFrameDataExtractor {
         return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size());
     }

+    public Set<String> getCategoricalFields() {
+        Set<String> categoricalFields = new HashSet<>();
+        for (ExtractedField extractedField : context.extractedFields.getAllFields()) {
+            String fieldName = extractedField.getName();
+            if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) {
+                categoricalFields.add(fieldName);
+            }
+        }
+        return categoricalFields;
+    }
+
     public static class DataSummary {

         public final long rows;
ExtractedFieldsDetector.java
@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.extractor;

+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.fieldcaps.FieldCapabilities;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
@@ -20,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.NameResolver;
 import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
 import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex;

 import java.util.ArrayList;
 import java.util.Arrays;
@@ -35,24 +38,24 @@ import java.util.stream.Stream;

 public class ExtractedFieldsDetector {

+    private static final Logger LOGGER = LogManager.getLogger(ExtractedFieldsDetector.class);
+
     /**
      * Fields to ignore. These are mostly internal meta fields.
      */
     private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no",
-        "_source", "_type", "_uid", "_version", "_feature", "_ignored");
+        "_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY);

     /**
      * The types supported by data frames
      */
-    private static final Set<String> COMPATIBLE_FIELD_TYPES;
+    public static final Set<String> CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));
+
+    private static final Set<String> NUMERICAL_TYPES;

     static {
-        Set<String> compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values())
+        Set<String> numericalTypes = Stream.of(NumberFieldMapper.NumberType.values())
             .map(NumberFieldMapper.NumberType::typeName)
             .collect(Collectors.toSet());
-        compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module
-
-        COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes);
+        numericalTypes.add("scaled_float");
+        NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes);
     }

     private final String[] index;
@@ -79,16 +82,18 @@ public class ExtractedFieldsDetector {
         // Ignore fields under the results object
         fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + "."));

+        includeAndExcludeFields(fields);
         removeFieldsWithIncompatibleTypes(fields);
-        includeAndExcludeFields(fields, index);
+        checkRequiredFieldsArePresent(fields);

         if (fields.isEmpty()) {
             throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index));
         }

         List<String> sortedFields = new ArrayList<>(fields);
         // We sort the fields to ensure the checksum for each document is deterministic
         Collections.sort(sortedFields);
-        ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse)
-            .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE);
-        if (extractedFields.getAllFields().isEmpty()) {
-            throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}", Arrays.toString(index));
-        }
+        ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
         if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
             extractedFields = fetchFromSourceIfSupported(extractedFields);
             if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
@@ -120,13 +125,25 @@ public class ExtractedFieldsDetector {
         while (fieldsIterator.hasNext()) {
             String field = fieldsIterator.next();
             Map<String, FieldCapabilities> fieldCaps = fieldCapabilitiesResponse.getField(field);
-            if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) {
+            if (fieldCaps == null) {
+                LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field);
                 fieldsIterator.remove();
+            } else {
+                Set<String> fieldTypes = fieldCaps.keySet();
+                if (NUMERICAL_TYPES.containsAll(fieldTypes)) {
+                    LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field);
+                } else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) {
+                    LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field);
+                } else {
+                    LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}",
+                        config.getId(), field, fieldTypes);
+                    fieldsIterator.remove();
+                }
+            }
         }
     }

-    private void includeAndExcludeFields(Set<String> fields, String[] index) {
+    private void includeAndExcludeFields(Set<String> fields) {
         FetchSourceContext analyzedFields = config.getAnalyzedFields();
         if (analyzedFields == null) {
             return;
@@ -159,6 +176,16 @@ public class ExtractedFieldsDetector {
         }
     }

+    private void checkRequiredFieldsArePresent(Set<String> fields) {
+        List<String> missingFields = config.getAnalysis().getRequiredFields()
+            .stream()
+            .filter(f -> fields.contains(f) == false)
+            .collect(Collectors.toList());
+        if (missingFields.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields);
+        }
+    }
+
     private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
         List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
         for (ExtractedField field : extractedFields.getDocValueFields()) {
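To summarise the detector's new compatibility rule in one expression (a restatement of the loop above, not code from the commit): a field is kept when

    boolean compatible = NUMERICAL_TYPES.containsAll(fieldTypes)
        || (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes));

so a text, keyword or ip field survives only for analyses, such as regression, that declare categorical support.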
AnalyticsProcessConfig.java
@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;

 import java.io.IOException;
 import java.util.Objects;
+import java.util.Set;

 public class AnalyticsProcessConfig implements ToXContentObject {

@@ -21,21 +22,24 @@ public class AnalyticsProcessConfig implements ToXContentObject {
     private static final String THREADS = "threads";
     private static final String ANALYSIS = "analysis";
     private static final String RESULTS_FIELD = "results_field";
+    private static final String CATEGORICAL_FIELDS = "categorical_fields";

     private final long rows;
     private final int cols;
     private final ByteSizeValue memoryLimit;
     private final int threads;
-    private final DataFrameAnalysis analysis;
     private final String resultsField;
+    private final Set<String> categoricalFields;
+    private final DataFrameAnalysis analysis;

     public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField,
-                                  DataFrameAnalysis analysis) {
+                                  Set<String> categoricalFields, DataFrameAnalysis analysis) {
         this.rows = rows;
         this.cols = cols;
         this.memoryLimit = Objects.requireNonNull(memoryLimit);
         this.threads = threads;
         this.resultsField = Objects.requireNonNull(resultsField);
+        this.categoricalFields = Objects.requireNonNull(categoricalFields);
         this.analysis = Objects.requireNonNull(analysis);
     }

@@ -51,6 +55,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
         builder.field(MEMORY_LIMIT, memoryLimit.getBytes());
         builder.field(THREADS, threads);
         builder.field(RESULTS_FIELD, resultsField);
+        builder.field(CATEGORICAL_FIELDS, categoricalFields);
         builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis));
         builder.endObject();
         return builder;
AnalyticsProcessManager.java
@@ -26,6 +26,7 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
@@ -283,8 +284,9 @@ public class AnalyticsProcessManager {

     private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
         DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
+        Set<String> categoricalFields = dataExtractor.getCategoricalFields();
         AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols,
-            config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis());
+            config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
         return processConfig;
     }
 }
@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
|
@ -19,46 +20,51 @@ public class ExtractedFieldTests extends ESTestCase {
|
|||
public void testValueGivenDocValue() {
|
||||
SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build();
|
||||
|
||||
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE);
|
||||
ExtractedField single = ExtractedField.newField("single", Collections.singleton("keyword"),
|
||||
ExtractedField.ExtractionMethod.DOC_VALUE);
|
||||
assertThat(single.value(hit), equalTo(new String[] { "bar" }));
|
||||
|
||||
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
|
||||
ExtractedField array = ExtractedField.newField("array", Collections.singleton("keyword"),
|
||||
ExtractedField.ExtractionMethod.DOC_VALUE);
|
||||
assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
|
||||
|
||||
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE);
|
||||
ExtractedField missing = ExtractedField.newField("missing",Collections.singleton("keyword"),
|
||||
ExtractedField.ExtractionMethod.DOC_VALUE);
|
||||
assertThat(missing.value(hit), equalTo(new Object[0]));
|
||||
}
|
||||
|
||||
public void testValueGivenScriptField() {
|
||||
SearchHit hit = new SearchHitBuilder(42).addField("single", "bar").addField("array", Arrays.asList("a", "b")).build();
|
||||
|
||||
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
|
||||
ExtractedField single = ExtractedField.newField("single",Collections.emptySet(),
|
||||
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
|
||||
assertThat(single.value(hit), equalTo(new String[] { "bar" }));
|
||||
|
||||
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
|
||||
ExtractedField array = ExtractedField.newField("array", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
|
||||
assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
|
||||
|
||||
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
|
||||
ExtractedField missing = ExtractedField.newField("missing", Collections.emptySet(), ExtractedField.ExtractionMethod.SCRIPT_FIELD);
|
||||
assertThat(missing.value(hit), equalTo(new Object[0]));
|
||||
}
|
||||
|
||||
public void testValueGivenSource() {
|
||||
SearchHit hit = new SearchHitBuilder(42).setSource("{\"single\":\"bar\",\"array\":[\"a\",\"b\"]}").build();
|
||||
|
||||
ExtractedField single = ExtractedField.newField("single", ExtractedField.ExtractionMethod.SOURCE);
|
||||
ExtractedField single = ExtractedField.newField("single", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
|
||||
assertThat(single.value(hit), equalTo(new String[] { "bar" }));
|
||||
|
||||
ExtractedField array = ExtractedField.newField("array", ExtractedField.ExtractionMethod.SOURCE);
|
||||
ExtractedField array = ExtractedField.newField("array", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
|
||||
assertThat(array.value(hit), equalTo(new String[] { "a", "b" }));
|
||||
|
||||
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE);
|
||||
ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
|
||||
assertThat(missing.value(hit), equalTo(new Object[0]));
|
||||
}
|
||||
|
||||
public void testValueGivenNestedSource() {
|
||||
SearchHit hit = new SearchHitBuilder(42).setSource("{\"level_1\":{\"level_2\":{\"foo\":\"bar\"}}}").build();
ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
assertThat(nested.value(hit), equalTo(new String[] { "bar" }));
}
@ -91,49 +97,54 @@ public class ExtractedFieldTests extends ESTestCase {
}
public void testValueGivenSourceAndHitWithNoSource() {
ExtractedField missing = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField missing = ExtractedField.newField("missing", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(missing.value(new SearchHitBuilder(3).build()), equalTo(new Object[0]));
}
public void testValueGivenMismatchingMethod() {
SearchHit hit = new SearchHitBuilder(42).addField("a", 1).setSource("{\"b\":2}").build();
ExtractedField invalidA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField invalidA = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(invalidA.value(hit), equalTo(new Object[0]));
ExtractedField validA = ExtractedField.newField("a", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField validA = ExtractedField.newField("a", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(validA.value(hit), equalTo(new Integer[] { 1 }));
ExtractedField invalidB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField invalidB = ExtractedField.newField("b", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(invalidB.value(hit), equalTo(new Object[0]));
ExtractedField validB = ExtractedField.newField("b", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField validB = ExtractedField.newField("b", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(validB.value(hit), equalTo(new Integer[] { 2 }));
}
public void testValueGivenEmptyHit() {
SearchHit hit = new SearchHitBuilder(42).build();
ExtractedField docValue = ExtractedField.newField("a", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField docValue = ExtractedField.newField("a", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE);
assertThat(docValue.value(hit), equalTo(new Object[0]));
ExtractedField sourceField = ExtractedField.newField("b", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField sourceField = ExtractedField.newField("b", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(sourceField.value(hit), equalTo(new Object[0]));
}
public void testNewTimeFieldGivenSource() {
expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.SOURCE));
expectThrows(IllegalArgumentException.class, () -> ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.SOURCE));
}
public void testValueGivenStringTimeField() {
final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", Long.toString(millis)).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
}
public void testValueGivenLongTimeField() {
final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
}
@ -141,13 +152,15 @@ public class ExtractedFieldTests extends ESTestCase {
// Prior to 6.x, timestamps were simply `long` milliseconds-past-the-epoch values
final long millis = randomLong();
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", millis).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(timeField.value(hit), equalTo(new Object[] { millis }));
}
public void testValueGivenUnknownFormatTimeField() {
final SearchHit hit = new SearchHitBuilder(randomInt()).addField("time", new Object()).build();
final ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
final ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(expectThrows(IllegalStateException.class, () -> timeField.value(hit)).getMessage(),
startsWith("Unexpected value for a time field"));
}
@ -155,14 +168,15 @@ public class ExtractedFieldTests extends ESTestCase {
public void testAliasVersusName() {
SearchHit hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build();
ExtractedField field = ExtractedField.newField("a", "a", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField field = ExtractedField.newField("a", "a", Collections.singleton("int"),
ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(field.getAlias(), equalTo("a"));
assertThat(field.getName(), equalTo("a"));
assertThat(field.value(hit), equalTo(new Integer[] { 1 }));
hit = new SearchHitBuilder(42).addField("a", 1).addField("b", 2).build();
field = ExtractedField.newField("a", "b", ExtractedField.ExtractionMethod.DOC_VALUE);
field = ExtractedField.newField("a", "b", Collections.singleton("int"), ExtractedField.ExtractionMethod.DOC_VALUE);
assertThat(field.getAlias(), equalTo("a"));
assertThat(field.getName(), equalTo("b"));
assertThat(field.value(hit), equalTo(new Integer[] { 2 }));
@ -170,11 +184,11 @@ public class ExtractedFieldTests extends ESTestCase {
public void testGetDocValueFormat() {
for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) {
assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(null));
assertThat(ExtractedField.newField("f", Collections.emptySet(), method).getDocValueFormat(), equalTo(null));
}
assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(),
equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(),
equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("doc_value_time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(), equalTo("epoch_millis"));
assertThat(ExtractedField.newTimeField("source_time", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(), equalTo("epoch_millis"));
}
}
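The updated call sites above illustrate the signature change this commit makes to ExtractedField: each factory now also takes the set of mapping types reported for the field, so that downstream consumers (such as the fields detector) can tell categorical fields apart. A minimal sketch of the new usage, assuming only the signatures exercised by these tests:

    import java.util.Collections;
    import java.util.Set;

    // The extra argument is the set of mapping types for the field (e.g. "keyword", "text").
    Set<String> types = Collections.singleton("keyword");
    ExtractedField field = ExtractedField.newField("field_1", types, ExtractedField.ExtractionMethod.DOC_VALUE);

    // Aliased variant, mirroring the nested-field test above:
    ExtractedField nested = ExtractedField.newField("alias", "level_1.level_2.foo", Collections.singleton("text"),
        ExtractedField.ExtractionMethod.SOURCE);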
@ -27,12 +27,18 @@ import static org.mockito.Mockito.when;
public class ExtractedFieldsTests extends ESTestCase {
public void testAllTypesOfFields() {
ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("ip"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));
@ -29,7 +29,8 @@ import static org.mockito.Mockito.when;
public class TimeBasedExtractedFieldsTests extends ESTestCase {
private ExtractedField timeField = ExtractedField.newTimeField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
private ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
public void testInvalidConstruction() {
expectThrows(IllegalArgumentException.class, () -> new TimeBasedExtractedFields(timeField, Collections.emptyList()));
@ -46,12 +47,18 @@ public class TimeBasedExtractedFieldsTests extends ESTestCase {
}
public void testAllTypesOfFields() {
ExtractedField docValue1 = ExtractedField.newField("doc1", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", ExtractedField.ExtractionMethod.SOURCE);
ExtractedField docValue1 = ExtractedField.newField("doc1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField docValue2 = ExtractedField.newField("doc2", Collections.singleton("float"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField scriptField1 = ExtractedField.newField("scripted1", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField scriptField2 = ExtractedField.newField("scripted2", Collections.emptySet(),
ExtractedField.ExtractionMethod.SCRIPT_FIELD);
ExtractedField sourceField1 = ExtractedField.newField("src1", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
ExtractedField sourceField2 = ExtractedField.newField("src2", Collections.singleton("text"),
ExtractedField.ExtractionMethod.SOURCE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField, Arrays.asList(timeField,
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));
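The same widening applies to time fields. A sketch of the updated usage, assuming the newTimeField signature shown in the tests above:

    // Time fields now carry their mapping types too; DOC_VALUE extraction still
    // reports "epoch_millis" as the doc value format (see testGetDocValueFormat above),
    // and SOURCE extraction remains rejected with an IllegalArgumentException.
    ExtractedField timeField = ExtractedField.newTimeField("time", Collections.singleton("date"),
        ExtractedField.ExtractionMethod.DOC_VALUE);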
@ -135,9 +135,11 @@ public class ScrollDataExtractorTests extends ESTestCase {
capturedSearchRequests = new ArrayList<>();
capturedContinueScrollIds = new ArrayList<>();
jobId = "test-job";
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE)));
Arrays.asList(timeField, ExtractedField.newField("field_1", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE)));
indices = Arrays.asList("index-1", "index-2");
query = QueryBuilders.matchAllQuery();
scriptFields = Collections.emptyList();
@ -16,16 +16,21 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
public class SearchHitToJsonProcessorTests extends ESTestCase {
public void testProcessGivenSingleHit() throws IOException {
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, missingField, singleField, arrayField));
@ -41,10 +46,14 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
}
public void testProcessGivenMultipleHits() throws IOException {
ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField timeField = ExtractedField.newField("time", Collections.singleton("date"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField missingField = ExtractedField.newField("missing", Collections.singleton("float"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField singleField = ExtractedField.newField("single", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedField arrayField = ExtractedField.newField("array", Collections.singleton("keyword"),
ExtractedField.ExtractionMethod.DOC_VALUE);
TimeBasedExtractedFields extractedFields = new TimeBasedExtractedFields(timeField,
Arrays.asList(timeField, missingField, singleField, arrayField));
@ -71,8 +71,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
indices = Arrays.asList("index-1", "index-2");
query = QueryBuilders.matchAllQuery();
extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.DOC_VALUE)));
ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE)));
scrollSize = 1000;
headers = Collections.emptyMap();
@ -288,8 +288,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException {
extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.SOURCE)));
ExtractedField.newField("field_1", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_2", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE)));
TestExtractor dataExtractor = createExtractor(false);
@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
@ -38,11 +39,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
private static final String RESULTS_FIELD = "ml";
public void testDetect_GivenFloatField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -52,12 +53,12 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
}
public void testDetect_GivenNumericFieldWithMultipleTypes() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -67,36 +68,36 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
}
public void testDetect_GivenNonNumericField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_keyword", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
public void testDetect_GivenFieldWithNumericAndNonNumericTypes() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericTypes() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("indecisive_field", "float", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
public void testDetect_GivenMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
public void testDetect_GivenOutlierDetectionAndMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -107,12 +108,46 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
}
public void testDetect_GivenRegressionAndMultipleFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.addAggregatableField("foo", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
assertThat(allFields.size(), equalTo(4));
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()),
contains("foo", "some_float", "some_keyword", "some_long"));
assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()),
contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
}
public void testDetect_GivenRegressionAndRequiredFieldMissing() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("required fields [foo] are missing"));
}
public void testDetect_GivenIgnoredField() {
FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder()
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("_id", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
@ -134,7 +169,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -151,7 +186,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]);
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected"));
@ -166,7 +201,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
@ -182,7 +217,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -199,7 +234,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " +
@ -215,7 +250,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -232,7 +267,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 4, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -251,7 +286,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 3, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -270,7 +305,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildAnalyticsConfig(), true, 2, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -280,11 +315,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.SOURCE)));
}
private static DataFrameAnalyticsConfig buildAnalyticsConfig() {
return buildAnalyticsConfig(null);
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return buildOutlierDetectionConfig(null);
}
private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analyzedFields) {
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig(FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
@ -293,6 +328,19 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
}
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
return buildRegressionConfig(dependentVariable, null);
}
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
.setAnalyzedFields(analyzedFields)
.setAnalysis(new Regression(dependentVariable))
.build();
}
private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
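The two regression tests above pin down the detector's behaviour for required fields: the dependent variable must be present in the field capabilities, otherwise detection fails before any extraction happens. A hedged sketch of that check follows; only the exception message is taken from the test assertion, while the helper shape and names are illustrative assumptions, not the actual ExtractedFieldsDetector code:

    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.stream.Collectors;
    import org.elasticsearch.ElasticsearchStatusException;
    import org.elasticsearch.rest.RestStatus;

    // Illustrative only: fail fast if any field the analysis requires is absent.
    static void checkRequiredFields(Set<String> requiredFields, Map<String, ?> fieldCapsByName) {
        List<String> missing = requiredFields.stream()
            .filter(field -> fieldCapsByName.containsKey(field) == false)
            .sorted()
            .collect(Collectors.toList());
        if (missing.isEmpty() == false) {
            // Renders as e.g. "required fields [foo] are missing", matching the assertion above
            throw new ElasticsearchStatusException("required fields {} are missing", RestStatus.BAD_REQUEST, missing);
        }
    }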
@ -607,7 +607,11 @@ setup:
"dest": {
"index": "index-bar_dest"
},
"analysis": {"outlier_detection":{}}
"analysis": {
"regression":{
"dependent_variable": "to_predict"
}
}
}
- match: { id: "bar" }
@ -768,7 +772,11 @@ setup:
"dest": {
"index": "index-bar_dest"
},
"analysis": {"outlier_detection":{}}
"analysis": {
"regression":{
"dependent_variable": "to_predict"
}
}
}
- match: { id: "bar" }
@ -930,3 +938,247 @@ setup:
xpack.ml.max_model_memory_limit: null
- match: {transient: {}}
---
"Test put regression given dependent_variable is not defined":
- do:
catch: /parse_exception/
ml.put_data_frame_analytics:
id: "regression-without-dependent-variable"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {}
}
}
---
"Test put regression given negative lambda":
- do:
catch: /\[lambda\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "regression-negative-lambda"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"lambda": -1.0
}
}
}
---
"Test put regression given negative gamma":
- do:
catch: /\[gamma\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "regression-negative-gamma"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"gamma": -1.0
}
}
}
---
"Test put regression given eta less than 1e-3":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "regression-eta-greater-less-than-valid"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"eta": 0.0009
}
}
}
---
"Test put regression given eta greater than one":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "regression-eta-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"eta": 1.00001
}
}
}
---
"Test put regression given maximum_number_trees is zero":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "regression-maximum-number-trees-is-zero"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"maximum_number_trees": 0
}
}
}
---
"Test put regression given maximum_number_trees is greater than 2k":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "regression-maximum-number-trees-greater-than-2k"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"maximum_number_trees": 2001
}
}
}
---
"Test put regression given feature_bag_fraction is negative":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "regression-feature-bag-fraction-is-negative"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"feature_bag_fraction": -0.0001
}
}
}
---
"Test put regression given feature_bag_fraction is greater than one":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "regression-feature-bag-fraction-is-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"feature_bag_fraction": 1.0001
}
}
}
---
"Test put regression given valid":
- do:
ml.put_data_frame_analytics:
id: "valid-regression"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
}
}
}
- match: { id: "valid-regression" }
- match: { source.index: ["index-source"] }
- match: { dest.index: "index-dest" }
- match: { analysis: {
"regression":{
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
}
}}
- is_true: create_time
- is_true: version
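Taken together, the REST tests above document the parameter ranges the new regression analysis enforces: lambda and gamma must be non-negative, eta must lie in [0.001, 1], maximum_number_trees in [1, 2000], feature_bag_fraction in (0, 1], and dependent_variable is mandatory. A compact sketch of equivalent checks, with the ranges lifted from the asserted error messages and everything else (method shape, exception types) illustrative:

    // Illustrative validation sketch; parameter names match the analysis parameters above.
    static void validateRegressionParams(String dependentVariable, Double lambda, Double gamma, Double eta,
                                         Integer maximumNumberTrees, Double featureBagFraction) {
        if (dependentVariable == null) {
            throw new IllegalArgumentException("[dependent_variable] is required");
        }
        if (lambda != null && lambda < 0) {
            throw new IllegalArgumentException("[lambda] must be a non-negative double");
        }
        if (gamma != null && gamma < 0) {
            throw new IllegalArgumentException("[gamma] must be a non-negative double");
        }
        if (eta != null && (eta < 0.001 || eta > 1)) {
            throw new IllegalArgumentException("[eta] must be a double in [0.001, 1]");
        }
        if (maximumNumberTrees != null && (maximumNumberTrees < 1 || maximumNumberTrees > 2000)) {
            throw new IllegalArgumentException("[maximum_number_trees] must be an integer in [1, 2000]");
        }
        if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1)) {
            throw new IllegalArgumentException("[feature_bag_fraction] must be a double in (0, 1]");
        }
    }

Note the missing dependent_variable surfaces as a parse_exception in the REST test above, since that failure is caught at parse time rather than by range validation.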
@ -11,15 +11,22 @@ import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.WarningFailureException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.test.rest.XPackRestTestConstants;
import org.elasticsearch.xpack.test.rest.XPackRestTestHelper;
import org.junit.Before;
@ -28,12 +35,12 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClusterRestartTestCase {
@ -41,14 +48,23 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
private static final String OLD_CLUSTER_JOB_ID = "ml-config-mappings-old-cluster-job";
private static final String NEW_CLUSTER_JOB_ID = "ml-config-mappings-new-cluster-job";
private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS =
mapOf(
"properties", mapOf(
"outlier_detection", mapOf(
"properties", mapOf(
"method", mapOf("type", "keyword"),
"n_neighbors", mapOf("type", "integer"),
"feature_influence_threshold", mapOf("type", "double")))));
private static final Map<String, Object> EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS = getDataFrameAnalysisMappings();
@SuppressWarnings("unchecked")
private static Map<String, Object> getDataFrameAnalysisMappings() {
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
builder.startObject();
ElasticsearchMappings.addDataFrameAnalyticsFields(builder);
builder.endObject();
Map<String, Object> asMap = builder.generator().contentType().xContent().createParser(
NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput()).map();
return (Map<String, Object>) asMap.get(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName());
} catch (IOException e) {
fail("Failed to initialize expected data frame analysis mappings");
}
return null;
}
@Override
protected Settings restClientSettings() {
@ -71,8 +87,8 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
// trigger .ml-config index creation
createAnomalyDetectorJob(OLD_CLUSTER_JOB_ID);
if (getOldClusterVersion().onOrAfter(Version.V_7_3_0)) {
// .ml-config has correct mappings from the start
assertThat(mappingsForDataFrameAnalysis(), is(equalTo(EXPECTED_DATA_FRAME_ANALYSIS_MAPPINGS)));
// .ml-config has mappings for analytics as the feature was introduced in 7.3.0
assertThat(mappingsForDataFrameAnalysis(), is(notNullValue()));
} else {
// .ml-config does not yet have correct mappings, it will need an update after cluster is upgraded
assertThat(mappingsForDataFrameAnalysis(), is(nullValue()));
@ -125,18 +141,4 @@ public class MlConfigIndexMappingsFullClusterRestartIT extends AbstractFullClust
mappings = (Map<String, Object>) XContentMapValues.extractValue(mappings, "properties", "analysis");
return mappings;
}
private static <K, V> Map<K, V> mapOf(K k1, V v1) {
Map<K, V> map = new HashMap<>();
map.put(k1, v1);
return map;
}
private static <K, V> Map<K, V> mapOf(K k1, V v1, K k2, V v2, K k3, V v3) {
Map<K, V> map = new HashMap<>();
map.put(k1, v1);
map.put(k2, v2);
map.put(k3, v3);
return map;
}
}