[7.x] Implement new analysis type: classification (#46537) (#47559)

This commit is contained in:
Przemysław Witek 2019-10-04 13:47:19 +02:00 committed by GitHub
parent 65c473bd4b
commit ec9b77deaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1826 additions and 420 deletions

View File

@ -0,0 +1,245 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
public class Classification implements DataFrameAnalysis {
public static Classification fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public static Builder builder(String dependentVariable) {
return new Builder(dependentVariable);
}
public static final ParseField NAME = new ParseField("classification");
static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
static final ParseField LAMBDA = new ParseField("lambda");
static final ParseField GAMMA = new ParseField("gamma");
static final ParseField ETA = new ParseField("eta");
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
private static final ConstructingObjectParser<Classification, Void> PARSER =
new ConstructingObjectParser<>(
NAME.getPreferredName(),
true,
a -> new Classification(
(String) a[0],
(Double) a[1],
(Double) a[2],
(Double) a[3],
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
}
private final String dependentVariable;
private final Double lambda;
private final Double gamma;
private final Double eta;
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
private final String predictionFieldName;
private final Double trainingPercent;
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
this.eta = eta;
this.maximumNumberTrees = maximumNumberTrees;
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
public String getDependentVariable() {
return dependentVariable;
}
public Double getLambda() {
return lambda;
}
public Double getGamma() {
return gamma;
}
public Double getEta() {
return eta;
}
public Integer getMaximumNumberTrees() {
return maximumNumberTrees;
}
public Double getFeatureBagFraction() {
return featureBagFraction;
}
public String getPredictionFieldName() {
return predictionFieldName;
}
public Double getTrainingPercent() {
return trainingPercent;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
if (lambda != null) {
builder.field(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
builder.field(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
builder.field(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
builder.endObject();
return builder;
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Classification that = (Classification) o;
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(lambda, that.lambda)
&& Objects.equals(gamma, that.gamma)
&& Objects.equals(eta, that.eta)
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent);
}
@Override
public String toString() {
return Strings.toString(this);
}
public static class Builder {
private String dependentVariable;
private Double lambda;
private Double gamma;
private Double eta;
private Integer maximumNumberTrees;
private Double featureBagFraction;
private String predictionFieldName;
private Double trainingPercent;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
}
public Builder setLambda(Double lambda) {
this.lambda = lambda;
return this;
}
public Builder setGamma(Double gamma) {
this.gamma = gamma;
return this;
}
public Builder setEta(Double eta) {
this.eta = eta;
return this;
}
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
this.maximumNumberTrees = maximumNumberTrees;
return this;
}
public Builder setFeatureBagFraction(Double featureBagFraction) {
this.featureBagFraction = featureBagFraction;
return this;
}
public Builder setPredictionFieldName(String predictionFieldName) {
this.predictionFieldName = predictionFieldName;
return this;
}
public Builder setTrainingPercent(Double trainingPercent) {
this.trainingPercent = trainingPercent;
return this;
}
public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
}
}
}

View File

@ -36,6 +36,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
Regression.NAME,
(p, c) -> Regression.fromXContent(p)));
(p, c) -> Regression.fromXContent(p)),
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
Classification.NAME,
(p, c) -> Classification.fromXContent(p)));
}
}

View File

@ -49,7 +49,10 @@ public class Regression implements DataFrameAnalysis {
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
private static final ConstructingObjectParser<Regression, Void> PARSER =
new ConstructingObjectParser<>(
NAME.getPreferredName(),
true,
a -> new Regression(
(String) a[0],
(Double) a[1],

View File

@ -1315,6 +1315,41 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
}
public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "test-put-df-analytics-classification";
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
.setId(configId)
.setSource(DataFrameAnalyticsSource.builder()
.setIndex("put-test-source-index")
.build())
.setDest(DataFrameAnalyticsDest.builder()
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
.builder("my_dependent_variable")
.setTrainingPercent(80.0)
.build())
.setDescription("this is a classification")
.build();
createIndex("put-test-source-index", defaultMappingForTest());
PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
new PutDataFrameAnalyticsRequest(config),
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();
assertThat(createdConfig.getId(), equalTo(config.getId()));
assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value
assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
assertThat(createdConfig.getDescription(), equalTo("this is a classification"));
}
public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "get-test-config";

View File

@ -684,7 +684,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(44, namedXContents.size());
assertEquals(45, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -718,9 +718,10 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(ShrinkAction.NAME));
assertTrue(names.contains(FreezeAction.NAME));
assertTrue(names.contains(SetPriorityAction.NAME));
assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
assertEquals(Integer.valueOf(3), categories.get(DataFrameAnalysis.class));
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));

View File

@ -0,0 +1,54 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
public static Classification randomClassification() {
return Classification.builder(randomAlphaOfLength(10))
.setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
.setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.build();
}
@Override
protected Classification createTestInstance() {
return randomClassification();
}
@Override
protected Classification doParseInstance(XContentParser parser) throws IOException {
return Classification.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -133,6 +133,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
@ -466,6 +467,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
// ML - Data frame analytics
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
// ML - Data frame evaluation
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new),

View File

@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.AbstractObjectParser;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
* Parameters used by both {@link Classification} and {@link Regression} analyses.
*/
public class BoostedTreeParams implements ToXContentFragment, Writeable {
static final String NAME = "boosted_tree_params";
public static final ParseField LAMBDA = new ParseField("lambda");
public static final ParseField GAMMA = new ParseField("gamma");
public static final ParseField ETA = new ParseField("eta");
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
static void declareFields(AbstractObjectParser<?, Void> parser) {
parser.declareDouble(optionalConstructorArg(), LAMBDA);
parser.declareDouble(optionalConstructorArg(), GAMMA);
parser.declareDouble(optionalConstructorArg(), ETA);
parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
}
private final Double lambda;
private final Double gamma;
private final Double eta;
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
BoostedTreeParams(@Nullable Double lambda,
@Nullable Double gamma,
@Nullable Double eta,
@Nullable Integer maximumNumberTrees,
@Nullable Double featureBagFraction) {
if (lambda != null && lambda < 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
}
if (gamma != null && gamma < 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
}
if (eta != null && (eta < 0.001 || eta > 1)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
}
if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
}
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
}
this.lambda = lambda;
this.gamma = gamma;
this.eta = eta;
this.maximumNumberTrees = maximumNumberTrees;
this.featureBagFraction = featureBagFraction;
}
BoostedTreeParams() {
this(null, null, null, null, null);
}
BoostedTreeParams(StreamInput in) throws IOException {
lambda = in.readOptionalDouble();
gamma = in.readOptionalDouble();
eta = in.readOptionalDouble();
maximumNumberTrees = in.readOptionalVInt();
featureBagFraction = in.readOptionalDouble();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalDouble(lambda);
out.writeOptionalDouble(gamma);
out.writeOptionalDouble(eta);
out.writeOptionalVInt(maximumNumberTrees);
out.writeOptionalDouble(featureBagFraction);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (lambda != null) {
builder.field(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
builder.field(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
builder.field(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
return builder;
}
Map<String, Object> getParams() {
Map<String, Object> params = new HashMap<>();
if (lambda != null) {
params.put(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
params.put(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
params.put(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
return params;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BoostedTreeParams that = (BoostedTreeParams) o;
return Objects.equals(lambda, that.lambda)
&& Objects.equals(gamma, that.gamma)
&& Objects.equals(eta, that.eta)
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction);
}
@Override
public int hashCode() {
return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
}
}

View File

@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class Classification implements DataFrameAnalysis {
public static final ParseField NAME = new ParseField("classification");
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new Classification(
(String) a[0],
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
(String) a[6],
(Integer) a[7],
(Double) a[8]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
return parser;
}
public static Classification fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}
private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName;
private final int numTopClasses;
private final double trainingPercent;
public Classification(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent) {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
}
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}
public Classification(String dependentVariable) {
this(dependentVariable, new BoostedTreeParams(), null, null, null);
}
public Classification(StreamInput in) throws IOException {
dependentVariable = in.readString();
boostedTreeParams = new BoostedTreeParams(in);
predictionFieldName = in.readOptionalString();
numTopClasses = in.readOptionalVInt();
trainingPercent = in.readDouble();
}
public String getDependentVariable() {
return dependentVariable;
}
public double getTrainingPercent() {
return trainingPercent;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(dependentVariable);
boostedTreeParams.writeTo(out);
out.writeOptionalString(predictionFieldName);
out.writeOptionalVInt(numTopClasses);
out.writeDouble(trainingPercent);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
boostedTreeParams.toXContent(builder, params);
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
builder.endObject();
return builder;
}
@Override
public Map<String, Object> getParams() {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
return params;
}
@Override
public boolean supportsCategoricalFields() {
return true;
}
@Override
public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
}
@Override
public boolean supportsMissingValues() {
return true;
}
@Override
public boolean persistsState() {
return false;
}
@Override
public String getStateDocId(String jobId) {
throw new UnsupportedOperationException();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Classification that = (Classification) o;
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& trainingPercent == that.trainingPercent;
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
}
}

View File

@ -9,35 +9,34 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider {
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> {
return Arrays.asList(
new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> {
boolean ignoreUnknownFields = (boolean) c;
return OutlierDetection.fromXContent(p, ignoreUnknownFields);
}));
namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
}),
new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
boolean ignoreUnknownFields = (boolean) c;
return Regression.fromXContent(p, ignoreUnknownFields);
}));
return namedXContent;
}),
new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Classification.NAME, (p, c) -> {
boolean ignoreUnknownFields = (boolean) c;
return Classification.fromXContent(p, ignoreUnknownFields);
})
);
}
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(),
OutlierDetection::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(),
Regression::new));
return namedWriteables;
return Arrays.asList(
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new)
);
}
}

View File

@ -21,16 +21,14 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class Regression implements DataFrameAnalysis {
public static final ParseField NAME = new ParseField("regression");
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
public static final ParseField LAMBDA = new ParseField("lambda");
public static final ParseField GAMMA = new ParseField("gamma");
public static final ParseField ETA = new ParseField("eta");
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
@ -38,17 +36,18 @@ public class Regression implements DataFrameAnalysis {
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6],
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new Regression(
(String) a[0],
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
(String) a[6],
(Double) a[7]));
parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
return parser;
}
@ -57,63 +56,30 @@ public class Regression implements DataFrameAnalysis {
}
private final String dependentVariable;
private final Double lambda;
private final Double gamma;
private final Double eta;
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName;
private final double trainingPercent;
public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
if (lambda != null && lambda < 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
}
this.lambda = lambda;
if (gamma != null && gamma < 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
}
this.gamma = gamma;
if (eta != null && (eta < 0.001 || eta > 1)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
}
this.eta = eta;
if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
}
this.maximumNumberTrees = maximumNumberTrees;
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
}
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}
public Regression(String dependentVariable) {
this(dependentVariable, null, null, null, null, null, null, null);
this(dependentVariable, new BoostedTreeParams(), null, null);
}
public Regression(StreamInput in) throws IOException {
dependentVariable = in.readString();
lambda = in.readOptionalDouble();
gamma = in.readOptionalDouble();
eta = in.readOptionalDouble();
maximumNumberTrees = in.readOptionalVInt();
featureBagFraction = in.readOptionalDouble();
boostedTreeParams = new BoostedTreeParams(in);
predictionFieldName = in.readOptionalString();
trainingPercent = in.readDouble();
}
@ -134,11 +100,7 @@ public class Regression implements DataFrameAnalysis {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(dependentVariable);
out.writeOptionalDouble(lambda);
out.writeOptionalDouble(gamma);
out.writeOptionalDouble(eta);
out.writeOptionalVInt(maximumNumberTrees);
out.writeOptionalDouble(featureBagFraction);
boostedTreeParams.writeTo(out);
out.writeOptionalString(predictionFieldName);
out.writeDouble(trainingPercent);
}
@ -147,21 +109,7 @@ public class Regression implements DataFrameAnalysis {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
if (lambda != null) {
builder.field(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
builder.field(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
builder.field(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
boostedTreeParams.toXContent(builder, params);
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
@ -174,21 +122,7 @@ public class Regression implements DataFrameAnalysis {
public Map<String, Object> getParams() {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
if (lambda != null) {
params.put(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
params.put(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
params.put(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
params.putAll(boostedTreeParams.getParams());
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
@ -220,24 +154,19 @@ public class Regression implements DataFrameAnalysis {
return jobId + "_regression_state#1";
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Regression that = (Regression) o;
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(lambda, that.lambda)
&& Objects.equals(gamma, that.gamma)
&& Objects.equals(eta, that.eta)
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& trainingPercent == that.trainingPercent;
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent);
}
}

View File

@ -7,9 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -21,17 +19,17 @@ public final class Types {
private Types() {}
private static final Set<String> CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));
private static final Set<String> CATEGORICAL_TYPES =
Collections.unmodifiableSet(
Stream.of("text", "keyword", "ip")
.collect(Collectors.toSet()));
private static final Set<String> NUMERICAL_TYPES;
static {
Set<String> numericalTypes = Stream.of(NumberFieldMapper.NumberType.values())
.map(NumberFieldMapper.NumberType::typeName)
.collect(Collectors.toSet());
numericalTypes.add("scaled_float");
NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes);
}
private static final Set<String> NUMERICAL_TYPES =
Collections.unmodifiableSet(
Stream.concat(
Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
Stream.of("scaled_float"))
.collect(Collectors.toSet()));
public static Set<String> categorical() {
return CATEGORICAL_TYPES;

View File

@ -30,6 +30,8 @@ import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
@ -449,19 +451,19 @@ public class ElasticsearchMappings {
.startObject(Regression.DEPENDENT_VARIABLE.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(Regression.LAMBDA.getPreferredName())
.startObject(BoostedTreeParams.LAMBDA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.GAMMA.getPreferredName())
.startObject(BoostedTreeParams.GAMMA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.ETA.getPreferredName())
.startObject(BoostedTreeParams.ETA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName())
.startObject(BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName())
.field(TYPE, INTEGER)
.endObject()
.startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName())
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
@ -472,6 +474,37 @@ public class ElasticsearchMappings {
.endObject()
.endObject()
.endObject()
.startObject(Classification.NAME.getPreferredName())
.startObject(PROPERTIES)
.startObject(Classification.DEPENDENT_VARIABLE.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(BoostedTreeParams.LAMBDA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(BoostedTreeParams.GAMMA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(BoostedTreeParams.ETA.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName())
.field(TYPE, INTEGER)
.endObject()
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(Classification.NUM_TOP_CLASSES.getPreferredName())
.field(TYPE, INTEGER)
.endObject()
.startObject(Classification.TRAINING_PERCENT.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
// re-used: CREATE_TIME

View File

@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
@ -303,13 +305,18 @@ public final class ReservedFieldNames {
OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
Regression.NAME.getPreferredName(),
Regression.DEPENDENT_VARIABLE.getPreferredName(),
Regression.LAMBDA.getPreferredName(),
Regression.GAMMA.getPreferredName(),
Regression.ETA.getPreferredName(),
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
Regression.TRAINING_PERCENT.getPreferredName(),
Classification.NAME.getPreferredName(),
Classification.DEPENDENT_VARIABLE.getPreferredName(),
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
Classification.NUM_TOP_CLASSES.getPreferredName(),
Classification.TRAINING_PERCENT.getPreferredName(),
BoostedTreeParams.LAMBDA.getPreferredName(),
BoostedTreeParams.GAMMA.getPreferredName(),
BoostedTreeParams.ETA.getPreferredName(),
BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(),
BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(),
ElasticsearchMappings.CONFIG_TYPE,

View File

@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedTreeParams> {
@Override
protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOException {
ConstructingObjectParser<BoostedTreeParams, Void> objParser =
new ConstructingObjectParser<>(
BoostedTreeParams.NAME,
true,
a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4]));
BoostedTreeParams.declareFields(objParser);
return objParser.apply(parser, null);
}
@Override
protected BoostedTreeParams createTestInstance() {
return createRandom();
}
public static BoostedTreeParams createRandom() {
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
}
@Override
protected Writeable.Reader<BoostedTreeParams> instanceReader() {
return BoostedTreeParams::new;
}
public void testConstructor_GivenNegativeLambda() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3));
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
}
public void testConstructor_GivenNegativeGamma() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3));
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
}
public void testConstructor_GivenEtaIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testConstructor_GivenEtaIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testConstructor_GivenMaximumNumberTreesIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testConstructor_GivenFeatureBagFractionIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
@Override
protected Classification doParseInstance(XContentParser parser) throws IOException {
return Classification.fromXContent(parser, false);
}
@Override
protected Classification createTestInstance() {
return createRandom();
}
public static Classification createRandom() {
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
}
@Override
protected Writeable.Reader<Classification> instanceReader() {
return Classification::new;
}
public void testConstructor_GivenTrainingPercentIsNull() {
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null);
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
public void testConstructor_GivenTrainingPercentIsBoundary() {
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0);
assertThat(classification.getTrainingPercent(), equalTo(1.0));
classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0);
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
}

View File

@ -28,15 +28,11 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
}
public static Regression createRandom() {
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
predictionFieldName, trainingPercent);
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent);
}
@Override
@ -44,84 +40,28 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
return Regression::new;
}
public void testRegression_GivenNegativeLambda() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
}
public void testRegression_GivenNegativeGamma() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
}
public void testRegression_GivenEtaIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testRegression_GivenEtaIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testRegression_GivenMaximumNumberTreesIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result", 100.0));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result", 100.0));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
public void testRegression_GivenTrainingPercentIsNull() {
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", null);
public void testConstructor_GivenTrainingPercentIsNull() {
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
public void testRegression_GivenTrainingPercentIsBoundary() {
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 1.0);
public void testConstructor_GivenTrainingPercentIsBoundary() {
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0);
assertThat(regression.getTrainingPercent(), equalTo(1.0));
regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0);
regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
public void testRegression_GivenTrainingPercentIsLessThanOne() {
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 0.999));
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testRegression_GivenTrainingPercentIsGreaterThan100() {
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0001));
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

View File

@ -73,6 +73,19 @@ integTest.runner {
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
'ml/data_frame_analytics_crud/Test put classification given dependent_variable is not defined',
'ml/data_frame_analytics_crud/Test put classification given negative lambda',
'ml/data_frame_analytics_crud/Test put classification given negative gamma',
'ml/data_frame_analytics_crud/Test put classification given eta less than 1e-3',
'ml/data_frame_analytics_crud/Test put classification given eta greater than one',
'ml/data_frame_analytics_crud/Test put classification given maximum_number_trees is zero',
'ml/data_frame_analytics_crud/Test put classification given maximum_number_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one',
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
'ml/evaluate_data_frame/Test given missing index',
'ml/evaluate_data_frame/Test given index does not exist',
'ml/evaluate_data_frame/Test given missing evaluation',

View File

@ -0,0 +1,314 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import com.google.common.collect.Ordering;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.junit.After;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.in;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
private String jobId;
private String sourceIndex;
private String destIndex;
@After
public void cleanup() throws Exception {
cleanUp();
}
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set");
{ // Index 350 rows, 300 of them being training rows.
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
.get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 300; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
for (int i = 300; i < 350; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
assertThat(resultsObject.containsKey("top_classes"), is(false));
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_100");
indexTrainingData(sourceIndex, 300);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
assertThat(resultsObject.containsKey("top_classes"), is(false));
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_50");
indexTrainingData(sourceIndex, 300);
DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
sourceIndex,
destIndex,
null,
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
int trainingRowsCount = 0;
int nonTrainingRowsCount = 0;
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
if ((boolean) resultsObject.get("is_training")) {
trainingRowsCount++;
} else {
nonTrainingRowsCount++;
}
assertThat(resultsObject.containsKey("top_classes"), is(false));
}
assertThat(trainingRowsCount, greaterThan(0));
assertThat(nonTrainingRowsCount, greaterThan(0));
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
}
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
initialize("classification_top_classes_requested");
indexTrainingData(sourceIndex, 300);
int numTopClasses = 2;
DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
sourceIndex,
destIndex,
null,
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertTopClasses(resultsObject, numTopClasses);
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
}
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";
this.destIndex = sourceIndex + "_results";
}
private static void indexTrainingData(String sourceIndex, int numRows) {
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
.get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numRows; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
return destDoc;
}
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
return resultsObject;
}
private static void assertTopClasses(Map<String, Object> resultsObject, int numTopClasses) {
assertThat(resultsObject.containsKey("top_classes"), is(true));
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
assertThat(topClasses, hasSize(numTopClasses));
List<String> classNames = new ArrayList<>(topClasses.size());
List<Double> classProbabilities = new ArrayList<>(topClasses.size());
for (Map<String, Object> topClass : topClasses) {
assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability")));
classNames.add((String) topClass.get("class_name"));
classProbabilities.add((Double) topClass.get("class_probability"));
}
// Assert that all the class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
// Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval.
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
// Assert that the top classes are listed in the order of decreasing probabilities.
assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
}
}

View File

@ -27,8 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
@ -136,13 +135,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
return response.getResponse().results();
}
protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField) {
protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
@Nullable String resultsField, DataFrameAnalysis analysis) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
configBuilder.setId(id);
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
configBuilder.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(new OutlierDetection());
configBuilder.setAnalysis(analysis);
return configBuilder.build();
}
@ -175,16 +174,6 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
.get();
}
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField, Regression regression) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
configBuilder.setId(id);
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(regression);
return configBuilder.build();
}
/**
* Asserts whether the audit messages fetched from index match provided prefixes.
* More specifically, in order to pass:

View File

@ -14,6 +14,7 @@ import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.junit.After;
import java.util.Map;
@ -68,7 +69,7 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
}
String id = "test_outlier_detection_with_missing_fields";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null);
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);

View File

@ -15,11 +15,13 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.junit.After;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -30,41 +32,52 @@ import static org.hamcrest.Matchers.is;
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
private String jobId;
private String sourceIndex;
private String destIndex;
@After
public void cleanup() {
public void cleanup() throws Exception {
cleanUp();
}
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
String jobId = "regression_single_numeric_feature_and_mixed_data_set";
String sourceIndex = jobId + "_source_index";
initialize("regression_single_numeric_feature_and_mixed_data_set");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
{ // Index 350 rows, 300 of them being training rows.
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
.get();
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 300; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
if (i < 300) {
indexRequest.source("feature", field, "variable", value);
} else {
indexRequest.source("feature", field);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
for (int i = 300; i < 350; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable"));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
registerAnalytics(config);
putAnalytics(config);
@ -76,71 +89,54 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
// TODO reenable this assertion when the backend is stable
// it seems for this case values can be as far off as 2.0
// double featureValue = (double) destDoc.get("feature");
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
// double predictionValue = (double) resultsObject.get("variable_prediction");
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
boolean expectedIsTraining = destDoc.containsKey("variable");
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(expectedIsTraining));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [regression_single_numeric_feature_and_mixed_data_set_source_index_results]",
"Finished reindexing to destination index [regression_single_numeric_feature_and_mixed_data_set_source_index_results]",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
assertModelStatePersisted(jobId);
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
String jobId = "regression_only_training_data_and_training_percent_is_hundred";
String sourceIndex = jobId + "_source_index";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
initialize("regression_only_training_data_and_training_percent_is_100");
{ // Index 350 rows, all of them being training rows.
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
indexRequest.source("feature", field, "variable", value);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable"));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
registerAnalytics(config);
putAnalytics(config);
@ -152,18 +148,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
@ -172,42 +157,43 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [regression_only_training_data_and_training_percent_is_hundred_source_index_results]",
"Finished reindexing to destination index [regression_only_training_data_and_training_percent_is_hundred_source_index_results]",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
assertModelStatePersisted(jobId);
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
String jobId = "regression_only_training_data_and_training_percent_is_fifty";
String sourceIndex = jobId + "_source_index";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
initialize("regression_only_training_data_and_training_percent_is_50");
{ // Index 350 rows, all of them being training rows.
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
indexRequest.source("feature", field, "variable", value);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable", null, null, null, null, null, null, 50.0));
DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
sourceIndex,
destIndex,
null,
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0));
registerAnalytics(config);
putAnalytics(config);
@ -221,21 +207,9 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
int nonTrainingRowsCount = 0;
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
if ((boolean) resultsObject.get("is_training")) {
@ -249,32 +223,27 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
"Started analytics",
"Creating destination index [regression_only_training_data_and_training_percent_is_fifty_source_index_results]",
"Finished reindexing to destination index [regression_only_training_data_and_training_percent_is_fifty_source_index_results]",
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
assertModelStatePersisted(jobId);
}
public void testStopAndRestart() throws Exception {
String jobId = "regression_stop_and_restart";
String sourceIndex = jobId + "_source_index";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
initialize("regression_stop_and_restart");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
indexRequest.source("feature", field, "variable", value);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source("feature", field, "variable", value);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
@ -282,9 +251,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable"));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
registerAnalytics(config);
putAnalytics(config);
@ -317,18 +284,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
@ -340,7 +296,32 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertModelStatePersisted(jobId);
}
private void assertModelStatePersisted(String jobId) {
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";
this.destIndex = sourceIndex + "_results";
}
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
return destDoc;
}
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
return resultsObject;
}
private static void assertModelStatePersisted(String jobId) {
String docId = jobId + "_regression_state#1";
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(docId))

View File

@ -72,7 +72,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_outlier_detection_with_few_docs";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null);
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);
@ -147,8 +147,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_outlier_detection_with_enough_docs_to_scroll";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);
@ -217,7 +216,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_outlier_detection_with_more_fields_than_docvalue_limit";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null);
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);
@ -280,8 +279,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_stop_outlier_detection_with_enough_docs_to_scroll";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);
@ -345,7 +343,12 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_outlier_detection_with_multiple_source_indices";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, destIndex, null);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(id)
.setSource(new DataFrameAnalyticsSource(sourceIndex, null))
.setDest(new DataFrameAnalyticsDest(destIndex, null))
.setAnalysis(new OutlierDetection())
.build();
registerAnalytics(config);
putAnalytics(config);
@ -402,7 +405,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_outlier_detection_with_pre_existing_dest_index";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, destIndex, null);
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, destIndex, null, new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);
@ -500,8 +503,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
String id = "test_outlier_detection_stop_and_restart";
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection());
registerAnalytics(config);
putAnalytics(config);

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
@ -21,7 +22,14 @@ public class CustomProcessorFactory {
public CustomProcessor create(DataFrameAnalysis analysis) {
if (analysis instanceof Regression) {
return new RegressionCustomProcessor(fieldNames, (Regression) analysis);
Regression regression = (Regression) analysis;
return new DatasetSplittingCustomProcessor(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent());
}
if (analysis instanceof Classification) {
Classification classification = (Classification) analysis;
return new DatasetSplittingCustomProcessor(
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent());
}
return row -> {};
}

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.List;
@ -18,7 +17,7 @@ import java.util.Random;
* This relies on the fact that when the dependent variable field
* is empty, then the row is not used for training but only to make predictions.
*/
class RegressionCustomProcessor implements CustomProcessor {
class DatasetSplittingCustomProcessor implements CustomProcessor {
private static final String EMPTY = "";
@ -27,10 +26,9 @@ class RegressionCustomProcessor implements CustomProcessor {
private final Random random = Randomness.get();
private boolean isFirstRow = true;
RegressionCustomProcessor(List<String> fieldNames, Regression regression) {
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, regression.getDependentVariable());
this.trainingPercent = regression.getTrainingPercent();
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent) {
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.trainingPercent = trainingPercent;
}
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.junit.Before;
import java.util.ArrayList;
@ -20,7 +19,7 @@ import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
public class RegressionCustomProcessorTests extends ESTestCase {
public class DatasetSplittingCustomProcessorTests extends ESTestCase {
private List<String> fields;
private int dependentVariableIndex;
@ -38,7 +37,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
}
public void testProcess_GivenRowsWithoutDependentVariableValue() {
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 50.0));
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0);
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
@ -56,7 +55,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
}
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 100.0));
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0);
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
@ -76,7 +75,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
double trainingFraction = trainingPercent / 100;
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, trainingPercent));
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent);
int runCount = 20;
int rowsCount = 1000;
@ -122,7 +121,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
}
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 1.0));
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0);
// We have some non-training rows and then a training row to check
// we maintain the first training row and not just the first row
@ -142,8 +141,4 @@ public class RegressionCustomProcessorTests extends ESTestCase {
assertThat(Arrays.equals(processedRow, row), is(true));
}
}
private static Regression regression(String dependentVariable, double trainingPercent) {
return new Regression(dependentVariable, null, null, null, null, null, null, trainingPercent);
}
}

View File

@ -1231,6 +1231,346 @@ setup:
- is_true: create_time
- is_true: version
---
"Test put classification given dependent_variable is not defined":
- do:
catch: /parse_exception/
ml.put_data_frame_analytics:
id: "classification-without-dependent-variable"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {}
}
}
---
"Test put classification given negative lambda":
- do:
catch: /\[lambda\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "classification-negative-lambda"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"lambda": -1.0
}
}
}
---
"Test put classification given negative gamma":
- do:
catch: /\[gamma\] must be a non-negative double/
ml.put_data_frame_analytics:
id: "classification-negative-gamma"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"gamma": -1.0
}
}
}
---
"Test put classification given eta less than 1e-3":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "classification-eta-greater-less-than-valid"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"eta": 0.0009
}
}
}
---
"Test put classification given eta greater than one":
- do:
catch: /\[eta\] must be a double in \[0.001, 1\]/
ml.put_data_frame_analytics:
id: "classification-eta-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"eta": 1.00001
}
}
}
---
"Test put classification given maximum_number_trees is zero":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "classification-maximum-number-trees-is-zero"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"maximum_number_trees": 0
}
}
}
---
"Test put classification given maximum_number_trees is greater than 2k":
- do:
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
ml.put_data_frame_analytics:
id: "classification-maximum-number-trees-greater-than-2k"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"maximum_number_trees": 2001
}
}
}
---
"Test put classification given feature_bag_fraction is negative":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "classification-feature-bag-fraction-is-negative"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"feature_bag_fraction": -0.0001
}
}
}
---
"Test put classification given feature_bag_fraction is greater than one":
- do:
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
ml.put_data_frame_analytics:
id: "classification-feature-bag-fraction-is-greater-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"feature_bag_fraction": 1.0001
}
}
}
---
"Test put classification given num_top_classes is less than zero":
- do:
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-less-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"num_top_classes": -1
}
}
}
---
"Test put classification given num_top_classes is greater than 1k":
- do:
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-greater-than-hundred"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"num_top_classes": 1001
}
}
}
---
"Test put classification given training_percent is less than one":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-less-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"training_percent": 0.999
}
}
}
---
"Test put classification given training_percent is greater than hundred":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-greater-than-hundred"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"training_percent": 100.1
}
}
}
---
"Test put classification given valid":
- do:
ml.put_data_frame_analytics:
id: "valid-classification"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"training_percent": 60.3
}
}
}
- match: { id: "valid-classification" }
- match: { source.index: ["index-source"] }
- match: { dest.index: "index-dest" }
- match: { analysis: {
"classification":{
"dependent_variable": "foo",
"lambda": 3.14,
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"training_percent": 60.3,
"num_top_classes": 0
}
}}
- is_true: create_time
- is_true: version
---
"Test put with description":