parent
65c473bd4b
commit
ec9b77deaa
|
@ -0,0 +1,245 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class Classification implements DataFrameAnalysis {
|
||||
|
||||
public static Classification fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public static Builder builder(String dependentVariable) {
|
||||
return new Builder(dependentVariable);
|
||||
}
|
||||
|
||||
public static final ParseField NAME = new ParseField("classification");
|
||||
|
||||
static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
||||
static final ParseField LAMBDA = new ParseField("lambda");
|
||||
static final ParseField GAMMA = new ParseField("gamma");
|
||||
static final ParseField ETA = new ParseField("eta");
|
||||
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
|
||||
private static final ConstructingObjectParser<Classification, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
true,
|
||||
a -> new Classification(
|
||||
(String) a[0],
|
||||
(Double) a[1],
|
||||
(Double) a[2],
|
||||
(Double) a[3],
|
||||
(Integer) a[4],
|
||||
(Double) a[5],
|
||||
(String) a[6],
|
||||
(Double) a[7]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
private final Double lambda;
|
||||
private final Double gamma;
|
||||
private final Double eta;
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
|
||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
this.eta = eta;
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
return dependentVariable;
|
||||
}
|
||||
|
||||
public Double getLambda() {
|
||||
return lambda;
|
||||
}
|
||||
|
||||
public Double getGamma() {
|
||||
return gamma;
|
||||
}
|
||||
|
||||
public Double getEta() {
|
||||
return eta;
|
||||
}
|
||||
|
||||
public Integer getMaximumNumberTrees() {
|
||||
return maximumNumberTrees;
|
||||
}
|
||||
|
||||
public Double getFeatureBagFraction() {
|
||||
return featureBagFraction;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
||||
public Double getTrainingPercent() {
|
||||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
if (lambda != null) {
|
||||
builder.field(LAMBDA.getPreferredName(), lambda);
|
||||
}
|
||||
if (gamma != null) {
|
||||
builder.field(GAMMA.getPreferredName(), gamma);
|
||||
}
|
||||
if (eta != null) {
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (maximumNumberTrees != null) {
|
||||
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (predictionFieldName != null) {
|
||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
if (trainingPercent != null) {
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Classification that = (Classification) o;
|
||||
return Objects.equals(dependentVariable, that.dependentVariable)
|
||||
&& Objects.equals(lambda, that.lambda)
|
||||
&& Objects.equals(gamma, that.gamma)
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String dependentVariable;
|
||||
private Double lambda;
|
||||
private Double gamma;
|
||||
private Double eta;
|
||||
private Integer maximumNumberTrees;
|
||||
private Double featureBagFraction;
|
||||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
}
|
||||
|
||||
public Builder setLambda(Double lambda) {
|
||||
this.lambda = lambda;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setGamma(Double gamma) {
|
||||
this.gamma = gamma;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEta(Double eta) {
|
||||
this.eta = eta;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFeatureBagFraction(Double featureBagFraction) {
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setPredictionFieldName(String predictionFieldName) {
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setTrainingPercent(Double trainingPercent) {
|
||||
this.trainingPercent = trainingPercent;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Classification build() {
|
||||
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -36,6 +36,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
|
|||
new NamedXContentRegistry.Entry(
|
||||
DataFrameAnalysis.class,
|
||||
Regression.NAME,
|
||||
(p, c) -> Regression.fromXContent(p)));
|
||||
(p, c) -> Regression.fromXContent(p)),
|
||||
new NamedXContentRegistry.Entry(
|
||||
DataFrameAnalysis.class,
|
||||
Classification.NAME,
|
||||
(p, c) -> Classification.fromXContent(p)));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
|
||||
private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
|
||||
private static final ConstructingObjectParser<Regression, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
true,
|
||||
a -> new Regression(
|
||||
(String) a[0],
|
||||
(Double) a[1],
|
||||
|
|
|
@ -1315,6 +1315,41 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
|
||||
}
|
||||
|
||||
public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Exception {
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
String configId = "test-put-df-analytics-classification";
|
||||
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
|
||||
.setId(configId)
|
||||
.setSource(DataFrameAnalyticsSource.builder()
|
||||
.setIndex("put-test-source-index")
|
||||
.build())
|
||||
.setDest(DataFrameAnalyticsDest.builder()
|
||||
.setIndex("put-test-dest-index")
|
||||
.build())
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
|
||||
.builder("my_dependent_variable")
|
||||
.setTrainingPercent(80.0)
|
||||
.build())
|
||||
.setDescription("this is a classification")
|
||||
.build();
|
||||
|
||||
createIndex("put-test-source-index", defaultMappingForTest());
|
||||
|
||||
PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
|
||||
new PutDataFrameAnalyticsRequest(config),
|
||||
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
|
||||
DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();
|
||||
assertThat(createdConfig.getId(), equalTo(config.getId()));
|
||||
assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
|
||||
assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
|
||||
assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
|
||||
assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value
|
||||
assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
|
||||
assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
|
||||
assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
|
||||
assertThat(createdConfig.getDescription(), equalTo("this is a classification"));
|
||||
}
|
||||
|
||||
public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
String configId = "get-test-config";
|
||||
|
|
|
@ -684,7 +684,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(44, namedXContents.size());
|
||||
assertEquals(45, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -718,9 +718,10 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
assertTrue(names.contains(ShrinkAction.NAME));
|
||||
assertTrue(names.contains(FreezeAction.NAME));
|
||||
assertTrue(names.contains(SetPriorityAction.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
|
||||
assertEquals(Integer.valueOf(3), categories.get(DataFrameAnalysis.class));
|
||||
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
|
||||
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
|
||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
|
||||
|
||||
public static Classification randomClassification() {
|
||||
return Classification.builder(randomAlphaOfLength(10))
|
||||
.setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||
.setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification createTestInstance() {
|
||||
return randomClassification();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||
return Classification.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -133,6 +133,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
|
@ -466,6 +467,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
// ML - Data frame analytics
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
|
||||
// ML - Data frame evaluation
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new),
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.AbstractObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentFragment;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* Parameters used by both {@link Classification} and {@link Regression} analyses.
|
||||
*/
|
||||
public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||
|
||||
static final String NAME = "boosted_tree_params";
|
||||
|
||||
public static final ParseField LAMBDA = new ParseField("lambda");
|
||||
public static final ParseField GAMMA = new ParseField("gamma");
|
||||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
|
||||
static void declareFields(AbstractObjectParser<?, Void> parser) {
|
||||
parser.declareDouble(optionalConstructorArg(), LAMBDA);
|
||||
parser.declareDouble(optionalConstructorArg(), GAMMA);
|
||||
parser.declareDouble(optionalConstructorArg(), ETA);
|
||||
parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
}
|
||||
|
||||
private final Double lambda;
|
||||
private final Double gamma;
|
||||
private final Double eta;
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
|
||||
BoostedTreeParams(@Nullable Double lambda,
|
||||
@Nullable Double gamma,
|
||||
@Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees,
|
||||
@Nullable Double featureBagFraction) {
|
||||
if (lambda != null && lambda < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
|
||||
}
|
||||
if (gamma != null && gamma < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
|
||||
}
|
||||
if (eta != null && (eta < 0.001 || eta > 1)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
|
||||
}
|
||||
if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
|
||||
}
|
||||
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
|
||||
}
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
this.eta = eta;
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
}
|
||||
|
||||
BoostedTreeParams() {
|
||||
this(null, null, null, null, null);
|
||||
}
|
||||
|
||||
BoostedTreeParams(StreamInput in) throws IOException {
|
||||
lambda = in.readOptionalDouble();
|
||||
gamma = in.readOptionalDouble();
|
||||
eta = in.readOptionalDouble();
|
||||
maximumNumberTrees = in.readOptionalVInt();
|
||||
featureBagFraction = in.readOptionalDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalDouble(lambda);
|
||||
out.writeOptionalDouble(gamma);
|
||||
out.writeOptionalDouble(eta);
|
||||
out.writeOptionalVInt(maximumNumberTrees);
|
||||
out.writeOptionalDouble(featureBagFraction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
if (lambda != null) {
|
||||
builder.field(LAMBDA.getPreferredName(), lambda);
|
||||
}
|
||||
if (gamma != null) {
|
||||
builder.field(GAMMA.getPreferredName(), gamma);
|
||||
}
|
||||
if (eta != null) {
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (maximumNumberTrees != null) {
|
||||
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
Map<String, Object> getParams() {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
if (lambda != null) {
|
||||
params.put(LAMBDA.getPreferredName(), lambda);
|
||||
}
|
||||
if (gamma != null) {
|
||||
params.put(GAMMA.getPreferredName(), gamma);
|
||||
}
|
||||
if (eta != null) {
|
||||
params.put(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (maximumNumberTrees != null) {
|
||||
params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
BoostedTreeParams that = (BoostedTreeParams) o;
|
||||
return Objects.equals(lambda, that.lambda)
|
||||
&& Objects.equals(gamma, that.gamma)
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class Classification implements DataFrameAnalysis {
|
||||
|
||||
public static final ParseField NAME = new ParseField("classification");
|
||||
|
||||
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
||||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
|
||||
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new Classification(
|
||||
(String) a[0],
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
||||
(String) a[6],
|
||||
(Integer) a[7],
|
||||
(Double) a[8]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static Classification fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
private final BoostedTreeParams boostedTreeParams;
|
||||
private final String predictionFieldName;
|
||||
private final int numTopClasses;
|
||||
private final double trainingPercent;
|
||||
|
||||
public Classification(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
@Nullable String predictionFieldName,
|
||||
@Nullable Integer numTopClasses,
|
||||
@Nullable Double trainingPercent) {
|
||||
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
|
||||
}
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
}
|
||||
|
||||
public Classification(String dependentVariable) {
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null, null);
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
dependentVariable = in.readString();
|
||||
boostedTreeParams = new BoostedTreeParams(in);
|
||||
predictionFieldName = in.readOptionalString();
|
||||
numTopClasses = in.readOptionalVInt();
|
||||
trainingPercent = in.readDouble();
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
return dependentVariable;
|
||||
}
|
||||
|
||||
public double getTrainingPercent() {
|
||||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(dependentVariable);
|
||||
boostedTreeParams.writeTo(out);
|
||||
out.writeOptionalString(predictionFieldName);
|
||||
out.writeOptionalVInt(numTopClasses);
|
||||
out.writeDouble(trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
boostedTreeParams.toXContent(builder, params);
|
||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||
if (predictionFieldName != null) {
|
||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> getParams() {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
params.putAll(boostedTreeParams.getParams());
|
||||
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||
if (predictionFieldName != null) {
|
||||
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsCategoricalFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RequiredField> getRequiredFields() {
|
||||
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsMissingValues() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean persistsState() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStateDocId(String jobId) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Classification that = (Classification) o;
|
||||
return Objects.equals(dependentVariable, that.dependentVariable)
|
||||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||
&& trainingPercent == that.trainingPercent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
|
||||
}
|
||||
}
|
|
@ -9,35 +9,34 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> {
|
||||
return Arrays.asList(
|
||||
new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> {
|
||||
boolean ignoreUnknownFields = (boolean) c;
|
||||
return OutlierDetection.fromXContent(p, ignoreUnknownFields);
|
||||
}));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
|
||||
}),
|
||||
new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> {
|
||||
boolean ignoreUnknownFields = (boolean) c;
|
||||
return Regression.fromXContent(p, ignoreUnknownFields);
|
||||
}));
|
||||
|
||||
return namedXContent;
|
||||
}),
|
||||
new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Classification.NAME, (p, c) -> {
|
||||
boolean ignoreUnknownFields = (boolean) c;
|
||||
return Classification.fromXContent(p, ignoreUnknownFields);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(),
|
||||
OutlierDetection::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(),
|
||||
Regression::new));
|
||||
|
||||
return namedWriteables;
|
||||
return Arrays.asList(
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,16 +21,14 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class Regression implements DataFrameAnalysis {
|
||||
|
||||
public static final ParseField NAME = new ParseField("regression");
|
||||
|
||||
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
||||
public static final ParseField LAMBDA = new ParseField("lambda");
|
||||
public static final ParseField GAMMA = new ParseField("gamma");
|
||||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
|
||||
|
@ -38,17 +36,18 @@ public class Regression implements DataFrameAnalysis {
|
|||
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
|
||||
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6],
|
||||
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new Regression(
|
||||
(String) a[0],
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
||||
(String) a[6],
|
||||
(Double) a[7]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -57,63 +56,30 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
private final Double lambda;
|
||||
private final Double gamma;
|
||||
private final Double eta;
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
private final BoostedTreeParams boostedTreeParams;
|
||||
private final String predictionFieldName;
|
||||
private final double trainingPercent;
|
||||
|
||||
public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
public Regression(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
@Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
||||
if (lambda != null && lambda < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
|
||||
}
|
||||
this.lambda = lambda;
|
||||
|
||||
if (gamma != null && gamma < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName());
|
||||
}
|
||||
this.gamma = gamma;
|
||||
|
||||
if (eta != null && (eta < 0.001 || eta > 1)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName());
|
||||
}
|
||||
this.eta = eta;
|
||||
|
||||
if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName());
|
||||
}
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
|
||||
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
|
||||
}
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, null, null, null, null, null, null, null);
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
dependentVariable = in.readString();
|
||||
lambda = in.readOptionalDouble();
|
||||
gamma = in.readOptionalDouble();
|
||||
eta = in.readOptionalDouble();
|
||||
maximumNumberTrees = in.readOptionalVInt();
|
||||
featureBagFraction = in.readOptionalDouble();
|
||||
boostedTreeParams = new BoostedTreeParams(in);
|
||||
predictionFieldName = in.readOptionalString();
|
||||
trainingPercent = in.readDouble();
|
||||
}
|
||||
|
@ -134,11 +100,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(dependentVariable);
|
||||
out.writeOptionalDouble(lambda);
|
||||
out.writeOptionalDouble(gamma);
|
||||
out.writeOptionalDouble(eta);
|
||||
out.writeOptionalVInt(maximumNumberTrees);
|
||||
out.writeOptionalDouble(featureBagFraction);
|
||||
boostedTreeParams.writeTo(out);
|
||||
out.writeOptionalString(predictionFieldName);
|
||||
out.writeDouble(trainingPercent);
|
||||
}
|
||||
|
@ -147,21 +109,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
if (lambda != null) {
|
||||
builder.field(LAMBDA.getPreferredName(), lambda);
|
||||
}
|
||||
if (gamma != null) {
|
||||
builder.field(GAMMA.getPreferredName(), gamma);
|
||||
}
|
||||
if (eta != null) {
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (maximumNumberTrees != null) {
|
||||
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
boostedTreeParams.toXContent(builder, params);
|
||||
if (predictionFieldName != null) {
|
||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
|
@ -174,21 +122,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
public Map<String, Object> getParams() {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
if (lambda != null) {
|
||||
params.put(LAMBDA.getPreferredName(), lambda);
|
||||
}
|
||||
if (gamma != null) {
|
||||
params.put(GAMMA.getPreferredName(), gamma);
|
||||
}
|
||||
if (eta != null) {
|
||||
params.put(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (maximumNumberTrees != null) {
|
||||
params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
params.putAll(boostedTreeParams.getParams());
|
||||
if (predictionFieldName != null) {
|
||||
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
|
@ -220,24 +154,19 @@ public class Regression implements DataFrameAnalysis {
|
|||
return jobId + "_regression_state#1";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Regression that = (Regression) o;
|
||||
return Objects.equals(dependentVariable, that.dependentVariable)
|
||||
&& Objects.equals(lambda, that.lambda)
|
||||
&& Objects.equals(gamma, that.gamma)
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& trainingPercent == that.trainingPercent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,9 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
|||
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
@ -21,17 +19,17 @@ public final class Types {
|
|||
|
||||
private Types() {}
|
||||
|
||||
private static final Set<String> CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip")));
|
||||
private static final Set<String> CATEGORICAL_TYPES =
|
||||
Collections.unmodifiableSet(
|
||||
Stream.of("text", "keyword", "ip")
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
private static final Set<String> NUMERICAL_TYPES;
|
||||
|
||||
static {
|
||||
Set<String> numericalTypes = Stream.of(NumberFieldMapper.NumberType.values())
|
||||
.map(NumberFieldMapper.NumberType::typeName)
|
||||
.collect(Collectors.toSet());
|
||||
numericalTypes.add("scaled_float");
|
||||
NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes);
|
||||
}
|
||||
private static final Set<String> NUMERICAL_TYPES =
|
||||
Collections.unmodifiableSet(
|
||||
Stream.concat(
|
||||
Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
|
||||
Stream.of("scaled_float"))
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
public static Set<String> categorical() {
|
||||
return CATEGORICAL_TYPES;
|
||||
|
|
|
@ -30,6 +30,8 @@ import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
|
@ -449,19 +451,19 @@ public class ElasticsearchMappings {
|
|||
.startObject(Regression.DEPENDENT_VARIABLE.getPreferredName())
|
||||
.field(TYPE, KEYWORD)
|
||||
.endObject()
|
||||
.startObject(Regression.LAMBDA.getPreferredName())
|
||||
.startObject(BoostedTreeParams.LAMBDA.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(Regression.GAMMA.getPreferredName())
|
||||
.startObject(BoostedTreeParams.GAMMA.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(Regression.ETA.getPreferredName())
|
||||
.startObject(BoostedTreeParams.ETA.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName())
|
||||
.startObject(BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName())
|
||||
.field(TYPE, INTEGER)
|
||||
.endObject()
|
||||
.startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName())
|
||||
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
|
||||
|
@ -472,6 +474,37 @@ public class ElasticsearchMappings {
|
|||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.startObject(Classification.NAME.getPreferredName())
|
||||
.startObject(PROPERTIES)
|
||||
.startObject(Classification.DEPENDENT_VARIABLE.getPreferredName())
|
||||
.field(TYPE, KEYWORD)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.LAMBDA.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.GAMMA.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.ETA.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName())
|
||||
.field(TYPE, INTEGER)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName())
|
||||
.field(TYPE, KEYWORD)
|
||||
.endObject()
|
||||
.startObject(Classification.NUM_TOP_CLASSES.getPreferredName())
|
||||
.field(TYPE, INTEGER)
|
||||
.endObject()
|
||||
.startObject(Classification.TRAINING_PERCENT.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
// re-used: CREATE_TIME
|
||||
|
|
|
@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
|
@ -303,13 +305,18 @@ public final class ReservedFieldNames {
|
|||
OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
|
||||
Regression.NAME.getPreferredName(),
|
||||
Regression.DEPENDENT_VARIABLE.getPreferredName(),
|
||||
Regression.LAMBDA.getPreferredName(),
|
||||
Regression.GAMMA.getPreferredName(),
|
||||
Regression.ETA.getPreferredName(),
|
||||
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
|
||||
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
|
||||
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Regression.TRAINING_PERCENT.getPreferredName(),
|
||||
Classification.NAME.getPreferredName(),
|
||||
Classification.DEPENDENT_VARIABLE.getPreferredName(),
|
||||
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Classification.NUM_TOP_CLASSES.getPreferredName(),
|
||||
Classification.TRAINING_PERCENT.getPreferredName(),
|
||||
BoostedTreeParams.LAMBDA.getPreferredName(),
|
||||
BoostedTreeParams.GAMMA.getPreferredName(),
|
||||
BoostedTreeParams.ETA.getPreferredName(),
|
||||
BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(),
|
||||
BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(),
|
||||
|
||||
ElasticsearchMappings.CONFIG_TYPE,
|
||||
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedTreeParams> {
|
||||
|
||||
@Override
|
||||
protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOException {
|
||||
ConstructingObjectParser<BoostedTreeParams, Void> objParser =
|
||||
new ConstructingObjectParser<>(
|
||||
BoostedTreeParams.NAME,
|
||||
true,
|
||||
a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4]));
|
||||
BoostedTreeParams.declareFields(objParser);
|
||||
return objParser.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BoostedTreeParams createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static BoostedTreeParams createRandom() {
|
||||
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
||||
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
||||
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
|
||||
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
|
||||
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
|
||||
return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<BoostedTreeParams> instanceReader() {
|
||||
return BoostedTreeParams::new;
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNegativeLambda() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNegativeGamma() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenEtaIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenEtaIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenMaximumNumberTreesIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenFeatureBagFractionIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
|
||||
@Override
|
||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||
return Classification.fromXContent(parser, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Classification> instanceReader() {
|
||||
return Classification::new;
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsNull() {
|
||||
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsBoundary() {
|
||||
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||
classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
}
|
|
@ -28,15 +28,11 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
||||
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
||||
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
|
||||
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
|
||||
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
||||
predictionFieldName, trainingPercent);
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -44,84 +40,28 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
return Regression::new;
|
||||
}
|
||||
|
||||
public void testRegression_GivenNegativeLambda() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenNegativeGamma() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenEtaIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenEtaIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenMaximumNumberTreesIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsNull() {
|
||||
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", null);
|
||||
public void testConstructor_GivenTrainingPercentIsNull() {
|
||||
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsBoundary() {
|
||||
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 1.0);
|
||||
public void testConstructor_GivenTrainingPercentIsBoundary() {
|
||||
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0);
|
||||
regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsLessThanOne() {
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 0.999));
|
||||
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsGreaterThan100() {
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0001));
|
||||
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
|
|
@ -73,6 +73,19 @@ integTest.runner {
|
|||
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
|
||||
'ml/data_frame_analytics_crud/Test put classification given dependent_variable is not defined',
|
||||
'ml/data_frame_analytics_crud/Test put classification given negative lambda',
|
||||
'ml/data_frame_analytics_crud/Test put classification given negative gamma',
|
||||
'ml/data_frame_analytics_crud/Test put classification given eta less than 1e-3',
|
||||
'ml/data_frame_analytics_crud/Test put classification given eta greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given maximum_number_trees is zero',
|
||||
'ml/data_frame_analytics_crud/Test put classification given maximum_number_trees is greater than 2k',
|
||||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
|
||||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
|
||||
'ml/evaluate_data_frame/Test given missing index',
|
||||
'ml/evaluate_data_frame/Test given index does not exist',
|
||||
'ml/evaluate_data_frame/Test given missing evaluation',
|
||||
|
|
|
@ -0,0 +1,314 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import com.google.common.collect.Ordering;
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.get.GetResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
||||
import static org.hamcrest.Matchers.hasKey;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.in;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||
|
||||
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
||||
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
||||
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
||||
private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
|
||||
|
||||
private String jobId;
|
||||
private String sourceIndex;
|
||||
private String destIndex;
|
||||
|
||||
@After
|
||||
public void cleanup() throws Exception {
|
||||
cleanUp();
|
||||
}
|
||||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
initialize("classification_single_numeric_feature_and_mixed_data_set");
|
||||
|
||||
{ // Index 350 rows, 300 of them being training rows.
|
||||
client().admin().indices().prepareCreate(sourceIndex)
|
||||
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
|
||||
.get();
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 300; i++) {
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
for (int i = 300; i < 350; i++) {
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||
initialize("classification_only_training_data_and_training_percent_is_100");
|
||||
indexTrainingData(sourceIndex, 300);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(true));
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||
initialize("classification_only_training_data_and_training_percent_is_50");
|
||||
indexTrainingData(sourceIndex, 300);
|
||||
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(
|
||||
jobId,
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
int trainingRowsCount = 0;
|
||||
int nonTrainingRowsCount = 0;
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
||||
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
// Let's just assert there's both training and non-training results
|
||||
if ((boolean) resultsObject.get("is_training")) {
|
||||
trainingRowsCount++;
|
||||
} else {
|
||||
nonTrainingRowsCount++;
|
||||
}
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
||||
}
|
||||
assertThat(trainingRowsCount, greaterThan(0));
|
||||
assertThat(nonTrainingRowsCount, greaterThan(0));
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
}
|
||||
|
||||
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
|
||||
initialize("classification_top_classes_requested");
|
||||
indexTrainingData(sourceIndex, 300);
|
||||
|
||||
int numTopClasses = 2;
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(
|
||||
jobId,
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
||||
assertTopClasses(resultsObject, numTopClasses);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
}
|
||||
|
||||
private void initialize(String jobId) {
|
||||
this.jobId = jobId;
|
||||
this.sourceIndex = jobId + "_source_index";
|
||||
this.destIndex = sourceIndex + "_results";
|
||||
}
|
||||
|
||||
private static void indexTrainingData(String sourceIndex, int numRows) {
|
||||
client().admin().indices().prepareCreate(sourceIndex)
|
||||
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
|
||||
.get();
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < numRows; i++) {
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
return destDoc;
|
||||
}
|
||||
|
||||
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
return resultsObject;
|
||||
}
|
||||
|
||||
private static void assertTopClasses(Map<String, Object> resultsObject, int numTopClasses) {
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(true));
|
||||
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
|
||||
assertThat(topClasses, hasSize(numTopClasses));
|
||||
List<String> classNames = new ArrayList<>(topClasses.size());
|
||||
List<Double> classProbabilities = new ArrayList<>(topClasses.size());
|
||||
for (Map<String, Object> topClass : topClasses) {
|
||||
assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability")));
|
||||
classNames.add((String) topClass.get("class_name"));
|
||||
classProbabilities.add((Double) topClass.get("class_probability"));
|
||||
}
|
||||
// Assert that all the class names come from the set of dependent variable values.
|
||||
classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
|
||||
// Assert that the first class listed in top classes is the same as the predicted class.
|
||||
assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
|
||||
// Assert that all the class probabilities lie within [0, 1] interval.
|
||||
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
||||
// Assert that the top classes are listed in the order of decreasing probabilities.
|
||||
assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
|
||||
}
|
||||
}
|
|
@ -27,8 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
|
||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||
|
@ -136,13 +135,13 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
|||
return response.getResponse().results();
|
||||
}
|
||||
|
||||
protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex,
|
||||
@Nullable String resultsField) {
|
||||
protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
|
||||
@Nullable String resultsField, DataFrameAnalysis analysis) {
|
||||
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
|
||||
configBuilder.setId(id);
|
||||
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
|
||||
configBuilder.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null));
|
||||
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
|
||||
configBuilder.setAnalysis(new OutlierDetection());
|
||||
configBuilder.setAnalysis(analysis);
|
||||
return configBuilder.build();
|
||||
}
|
||||
|
||||
|
@ -175,16 +174,6 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
|||
.get();
|
||||
}
|
||||
|
||||
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
|
||||
@Nullable String resultsField, Regression regression) {
|
||||
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
|
||||
configBuilder.setId(id);
|
||||
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
|
||||
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
|
||||
configBuilder.setAnalysis(regression);
|
||||
return configBuilder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts whether the audit messages fetched from index match provided prefixes.
|
||||
* More specifically, in order to pass:
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.action.support.WriteRequest;
|
|||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.Map;
|
||||
|
@ -68,7 +69,7 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_with_missing_fields";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null);
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
|
|
@ -15,11 +15,13 @@ import org.elasticsearch.index.query.QueryBuilders;
|
|||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -30,41 +32,52 @@ import static org.hamcrest.Matchers.is;
|
|||
|
||||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
||||
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
||||
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
||||
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
|
||||
|
||||
private String jobId;
|
||||
private String sourceIndex;
|
||||
private String destIndex;
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
public void cleanup() throws Exception {
|
||||
cleanUp();
|
||||
}
|
||||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
String jobId = "regression_single_numeric_feature_and_mixed_data_set";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
{ // Index 350 rows, 300 of them being training rows.
|
||||
client().admin().indices().prepareCreate(sourceIndex)
|
||||
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
|
||||
.get();
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 300; i++) {
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
||||
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
if (i < 300) {
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
} else {
|
||||
indexRequest.source("feature", field);
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
for (int i = 300; i < 350; i++) {
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable"));
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -76,71 +89,54 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
||||
|
||||
// TODO reenable this assertion when the backend is stable
|
||||
// it seems for this case values can be as far off as 2.0
|
||||
|
||||
// double featureValue = (double) destDoc.get("feature");
|
||||
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
|
||||
// double predictionValue = (double) resultsObject.get("variable_prediction");
|
||||
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
||||
|
||||
boolean expectedIsTraining = destDoc.containsKey("variable");
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(expectedIsTraining));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [regression_single_numeric_feature_and_mixed_data_set_source_index_results]",
|
||||
"Finished reindexing to destination index [regression_single_numeric_feature_and_mixed_data_set_source_index_results]",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
assertModelStatePersisted(jobId);
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||
String jobId = "regression_only_training_data_and_training_percent_is_hundred";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
initialize("regression_only_training_data_and_training_percent_is_100");
|
||||
|
||||
{ // Index 350 rows, all of them being training rows.
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable"));
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -152,18 +148,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
|
@ -172,42 +157,43 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [regression_only_training_data_and_training_percent_is_hundred_source_index_results]",
|
||||
"Finished reindexing to destination index [regression_only_training_data_and_training_percent_is_hundred_source_index_results]",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
assertModelStatePersisted(jobId);
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||
String jobId = "regression_only_training_data_and_training_percent_is_fifty";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
initialize("regression_only_training_data_and_training_percent_is_50");
|
||||
|
||||
{ // Index 350 rows, all of them being training rows.
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable", null, null, null, null, null, null, 50.0));
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(
|
||||
jobId,
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -221,21 +207,9 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
int nonTrainingRowsCount = 0;
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
// Let's just assert there's both training and non-training results
|
||||
if ((boolean) resultsObject.get("is_training")) {
|
||||
|
@ -249,32 +223,27 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
"Started analytics",
|
||||
"Creating destination index [regression_only_training_data_and_training_percent_is_fifty_source_index_results]",
|
||||
"Finished reindexing to destination index [regression_only_training_data_and_training_percent_is_fifty_source_index_results]",
|
||||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
assertModelStatePersisted(jobId);
|
||||
}
|
||||
|
||||
public void testStopAndRestart() throws Exception {
|
||||
String jobId = "regression_stop_and_restart";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
initialize("regression_stop_and_restart");
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source("feature", field, "variable", value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
|
@ -282,9 +251,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable"));
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -317,18 +284,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
|
@ -340,7 +296,32 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertModelStatePersisted(jobId);
|
||||
}
|
||||
|
||||
private void assertModelStatePersisted(String jobId) {
|
||||
private void initialize(String jobId) {
|
||||
this.jobId = jobId;
|
||||
this.sourceIndex = jobId + "_source_index";
|
||||
this.destIndex = sourceIndex + "_results";
|
||||
}
|
||||
|
||||
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
return destDoc;
|
||||
}
|
||||
|
||||
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
return resultsObject;
|
||||
}
|
||||
|
||||
private static void assertModelStatePersisted(String jobId) {
|
||||
String docId = jobId + "_regression_state#1";
|
||||
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||
.setQuery(QueryBuilders.idsQuery().addIds(docId))
|
||||
|
|
|
@ -72,7 +72,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_with_few_docs";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null);
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -147,8 +147,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_with_enough_docs_to_scroll";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
|
||||
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -217,7 +216,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_with_more_fields_than_docvalue_limit";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null);
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -280,8 +279,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_stop_outlier_detection_with_enough_docs_to_scroll";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
|
||||
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -345,7 +343,12 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_with_multiple_source_indices";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, destIndex, null);
|
||||
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
|
||||
.setId(id)
|
||||
.setSource(new DataFrameAnalyticsSource(sourceIndex, null))
|
||||
.setDest(new DataFrameAnalyticsDest(destIndex, null))
|
||||
.setAnalysis(new OutlierDetection())
|
||||
.build();
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -402,7 +405,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_with_pre_existing_dest_index";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, destIndex, null);
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, destIndex, null, new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -500,8 +503,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
|||
}
|
||||
|
||||
String id = "test_outlier_detection_stop_and_restart";
|
||||
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
|
||||
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection());
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
|
||||
|
@ -21,7 +22,14 @@ public class CustomProcessorFactory {
|
|||
|
||||
public CustomProcessor create(DataFrameAnalysis analysis) {
|
||||
if (analysis instanceof Regression) {
|
||||
return new RegressionCustomProcessor(fieldNames, (Regression) analysis);
|
||||
Regression regression = (Regression) analysis;
|
||||
return new DatasetSplittingCustomProcessor(
|
||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent());
|
||||
}
|
||||
if (analysis instanceof Classification) {
|
||||
Classification classification = (Classification) analysis;
|
||||
return new DatasetSplittingCustomProcessor(
|
||||
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent());
|
||||
}
|
||||
return row -> {};
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.common.Randomness;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -18,7 +17,7 @@ import java.util.Random;
|
|||
* This relies on the fact that when the dependent variable field
|
||||
* is empty, then the row is not used for training but only to make predictions.
|
||||
*/
|
||||
class RegressionCustomProcessor implements CustomProcessor {
|
||||
class DatasetSplittingCustomProcessor implements CustomProcessor {
|
||||
|
||||
private static final String EMPTY = "";
|
||||
|
||||
|
@ -27,10 +26,9 @@ class RegressionCustomProcessor implements CustomProcessor {
|
|||
private final Random random = Randomness.get();
|
||||
private boolean isFirstRow = true;
|
||||
|
||||
RegressionCustomProcessor(List<String> fieldNames, Regression regression) {
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, regression.getDependentVariable());
|
||||
this.trainingPercent = regression.getTrainingPercent();
|
||||
|
||||
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent) {
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.trainingPercent = trainingPercent;
|
||||
}
|
||||
|
||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
|
@ -6,7 +6,6 @@
|
|||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -20,7 +19,7 @@ import static org.hamcrest.Matchers.greaterThan;
|
|||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
|
||||
public class RegressionCustomProcessorTests extends ESTestCase {
|
||||
public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
||||
|
||||
private List<String> fields;
|
||||
private int dependentVariableIndex;
|
||||
|
@ -38,7 +37,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 50.0));
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -56,7 +55,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 100.0));
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -76,7 +75,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
|
|||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
||||
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
|
||||
double trainingFraction = trainingPercent / 100;
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, trainingPercent));
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent);
|
||||
|
||||
int runCount = 20;
|
||||
int rowsCount = 1000;
|
||||
|
@ -122,7 +121,7 @@ public class RegressionCustomProcessorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 1.0));
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0);
|
||||
|
||||
// We have some non-training rows and then a training row to check
|
||||
// we maintain the first training row and not just the first row
|
||||
|
@ -142,8 +141,4 @@ public class RegressionCustomProcessorTests extends ESTestCase {
|
|||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
}
|
||||
|
||||
private static Regression regression(String dependentVariable, double trainingPercent) {
|
||||
return new Regression(dependentVariable, null, null, null, null, null, null, trainingPercent);
|
||||
}
|
||||
}
|
|
@ -1231,6 +1231,346 @@ setup:
|
|||
- is_true: create_time
|
||||
- is_true: version
|
||||
|
||||
---
|
||||
"Test put classification given dependent_variable is not defined":
|
||||
|
||||
- do:
|
||||
catch: /parse_exception/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-without-dependent-variable"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given negative lambda":
|
||||
|
||||
- do:
|
||||
catch: /\[lambda\] must be a non-negative double/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-negative-lambda"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"lambda": -1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given negative gamma":
|
||||
|
||||
- do:
|
||||
catch: /\[gamma\] must be a non-negative double/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-negative-gamma"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"gamma": -1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given eta less than 1e-3":
|
||||
|
||||
- do:
|
||||
catch: /\[eta\] must be a double in \[0.001, 1\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-eta-greater-less-than-valid"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"eta": 0.0009
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given eta greater than one":
|
||||
|
||||
- do:
|
||||
catch: /\[eta\] must be a double in \[0.001, 1\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-eta-greater-than-one"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"eta": 1.00001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given maximum_number_trees is zero":
|
||||
|
||||
- do:
|
||||
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-maximum-number-trees-is-zero"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"maximum_number_trees": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given maximum_number_trees is greater than 2k":
|
||||
|
||||
- do:
|
||||
catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-maximum-number-trees-greater-than-2k"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"maximum_number_trees": 2001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given feature_bag_fraction is negative":
|
||||
|
||||
- do:
|
||||
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-feature-bag-fraction-is-negative"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"feature_bag_fraction": -0.0001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given feature_bag_fraction is greater than one":
|
||||
|
||||
- do:
|
||||
catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-feature-bag-fraction-is-greater-than-one"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"feature_bag_fraction": 1.0001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given num_top_classes is less than zero":
|
||||
|
||||
- do:
|
||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-less-than-one"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"num_top_classes": -1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given num_top_classes is greater than 1k":
|
||||
|
||||
- do:
|
||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-greater-than-hundred"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"num_top_classes": 1001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given training_percent is less than one":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-less-than-one"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 0.999
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given training_percent is greater than hundred":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-greater-than-hundred"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 100.1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put classification given valid":
|
||||
|
||||
- do:
|
||||
ml.put_data_frame_analytics:
|
||||
id: "valid-classification"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"lambda": 3.14,
|
||||
"gamma": 0.42,
|
||||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3
|
||||
}
|
||||
}
|
||||
}
|
||||
- match: { id: "valid-classification" }
|
||||
- match: { source.index: ["index-source"] }
|
||||
- match: { dest.index: "index-dest" }
|
||||
- match: { analysis: {
|
||||
"classification":{
|
||||
"dependent_variable": "foo",
|
||||
"lambda": 3.14,
|
||||
"gamma": 0.42,
|
||||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3,
|
||||
"num_top_classes": 0
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
- is_true: version
|
||||
|
||||
---
|
||||
"Test put with description":
|
||||
|
||||
|
|
Loading…
Reference in New Issue