[7.x][ML] Add loss_function to regression (#56118) (#56187)

Adds parameters `loss_function` and `loss_function_parameter`
to regression.

Backport of #56118
This commit is contained in:
Dimitris Athanasiou 2020-05-05 14:59:51 +03:00 committed by GitHub
parent c38388c506
commit 75dadb7a6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 313 additions and 50 deletions

View File

@ -22,12 +22,16 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class Regression implements DataFrameAnalysis {
public static Regression fromXContent(XContentParser parser) {
@ -50,6 +54,8 @@ public class Regression implements DataFrameAnalysis {
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
private static final ConstructingObjectParser<Regression, Void> PARSER =
new ConstructingObjectParser<>(
@ -65,7 +71,10 @@ public class Regression implements DataFrameAnalysis {
(Integer) a[6],
(String) a[7],
(Double) a[8],
(Long) a[9]));
(Long) a[9],
(LossFunction) a[10],
(Double) a[11]
));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@ -78,6 +87,13 @@ public class Regression implements DataFrameAnalysis {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
PARSER.declareField(optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return LossFunction.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, LOSS_FUNCTION, ObjectParser.ValueType.STRING);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
}
private final String dependentVariable;
@ -90,11 +106,14 @@ public class Regression implements DataFrameAnalysis {
private final String predictionFieldName;
private final Double trainingPercent;
private final Long randomizeSeed;
private final LossFunction lossFunction;
private final Double lossFunctionParameter;
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -105,6 +124,8 @@ public class Regression implements DataFrameAnalysis {
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.randomizeSeed = randomizeSeed;
this.lossFunction = lossFunction;
this.lossFunctionParameter = lossFunctionParameter;
}
@Override
@ -152,6 +173,14 @@ public class Regression implements DataFrameAnalysis {
return randomizeSeed;
}
/**
 * @return the loss function this regression analysis was configured with,
 *         or {@code null} if it was not explicitly set
 */
public LossFunction getLossFunction() {
    return lossFunction;
}
/**
 * @return the optional parameter of the loss function,
 *         or {@code null} if it was not explicitly set
 */
public Double getLossFunctionParameter() {
    return lossFunctionParameter;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -183,6 +212,12 @@ public class Regression implements DataFrameAnalysis {
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (lossFunction != null) {
builder.field(LOSS_FUNCTION.getPreferredName(), lossFunction);
}
if (lossFunctionParameter != null) {
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
builder.endObject();
return builder;
}
@ -190,7 +225,7 @@ public class Regression implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed);
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
}
@Override
@ -207,7 +242,9 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed);
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(lossFunction, that.lossFunction)
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
}
@Override
@ -226,6 +263,8 @@ public class Regression implements DataFrameAnalysis {
private String predictionFieldName;
private Double trainingPercent;
private Long randomizeSeed;
private LossFunction lossFunction;
private Double lossFunctionParameter;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -276,9 +315,32 @@ public class Regression implements DataFrameAnalysis {
return this;
}
/**
 * Sets the loss function the regression analysis should use.
 *
 * @param lossFunction the loss function; may be {@code null}, in which case the
 *                     setting is omitted from the request (the server presumably
 *                     applies its own default — confirm against server-side docs)
 * @return this builder
 */
public Builder setLossFunction(LossFunction lossFunction) {
    this.lossFunction = lossFunction;
    return this;
}
/**
 * Sets the optional parameter of the loss function.
 *
 * @param lossFunctionParameter the parameter value; may be {@code null}, in which case
 *                              the setting is omitted from the request. The server
 *                              rejects values that are not strictly positive.
 * @return this builder
 */
public Builder setLossFunctionParameter(Double lossFunctionParameter) {
    this.lossFunctionParameter = lossFunctionParameter;
    return this;
}
public Regression build() {
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed);
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
}
}
/**
 * The loss functions available for regression. Serialized in lower case
 * (e.g. {@code mse}) and parsed case-insensitively.
 */
public enum LossFunction {
    MSE, MSLE, HUBER;

    // Lower-case form used when rendering the enum as XContent; computed once per constant.
    private final String lowerCaseName;

    LossFunction() {
        lowerCaseName = name().toLowerCase(Locale.ROOT);
    }

    // Case-insensitive parse used by the XContent parser above.
    private static LossFunction fromString(String value) {
        return valueOf(value.toUpperCase(Locale.ROOT));
    }

    @Override
    public String toString() {
        return lowerCaseName;
    }
}
}

View File

@ -1356,6 +1356,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setMaxTrees(10)
.setFeatureBagFraction(0.5)
.setNumTopFeatureImportanceValues(3)
.setLossFunction(org.elasticsearch.client.ml.dataframe.Regression.LossFunction.MSLE)
.setLossFunctionParameter(1.0)
.build())
.setDescription("this is a regression")
.build();

View File

@ -151,6 +151,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
@ -3007,6 +3008,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setPredictionFieldName("my_prediction_field_name") // <8>
.setTrainingPercent(50.0) // <9>
.setRandomizeSeed(1234L) // <10>
.setLossFunction(Regression.LossFunction.MSE) // <11>
.setLossFunctionParameter(1.0) // <12>
.build();
// end::put-data-frame-analytics-regression

View File

@ -35,6 +35,8 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setLossFunction(randomBoolean() ? null : randomFrom(Regression.LossFunction.values()))
.setLossFunctionParameter(randomBoolean() ? null : randomDoubleBetween(1.0, Double.MAX_VALUE, true))
.build();
}

View File

@ -143,6 +143,8 @@ include-tagged::{doc-tests-file}[{api}-regression]
<8> The name of the prediction field in the results object.
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<10> The seed to be used by the random generator that picks which rows are used in training.
<11> The loss function used for regression. Defaults to `mse`.
<12> An optional, strictly positive parameter to the loss function.
==== Analyzed fields

View File

@ -225,6 +225,15 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
(Optional, double)
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
`loss_function`::::
(Optional, string)
The loss function used during regression. Available options are `mse` (mean squared error),
`msle` (mean squared logarithmic error), `huber` (pseudo-Huber loss). Defaults to `mse`.
`loss_function_parameter`::::
(Optional, double)
A strictly positive number that is used as a parameter to the `loss_function`.
`max_trees`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=max-trees]

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.NumberFieldMapper;
@ -21,6 +22,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
@ -36,6 +38,8 @@ public class Regression implements DataFrameAnalysis {
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
@ -51,12 +55,21 @@ public class Regression implements DataFrameAnalysis {
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
(String) a[7],
(Double) a[8],
(Long) a[9]));
(Long) a[9],
(LossFunction) a[10],
(Double) a[11]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
parser.declareField(optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return LossFunction.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, LOSS_FUNCTION, ObjectParser.ValueType.STRING);
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
return parser;
}
@ -69,12 +82,16 @@ public class Regression implements DataFrameAnalysis {
private final String predictionFieldName;
private final double trainingPercent;
private final long randomizeSeed;
private final LossFunction lossFunction;
private final Double lossFunctionParameter;
public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed) {
@Nullable Long randomizeSeed,
@Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter) {
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
@ -83,10 +100,16 @@ public class Regression implements DataFrameAnalysis {
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
// Prior to introducing the loss function setting only MSE was implemented
this.lossFunction = lossFunction == null ? LossFunction.MSE : lossFunction;
if (lossFunctionParameter != null && lossFunctionParameter <= 0.0) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
}
this.lossFunctionParameter = lossFunctionParameter;
}
public Regression(String dependentVariable) {
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null);
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
}
public Regression(StreamInput in) throws IOException {
@ -99,6 +122,14 @@ public class Regression implements DataFrameAnalysis {
} else {
randomizeSeed = Randomness.get().nextLong();
}
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
lossFunction = in.readEnum(LossFunction.class);
lossFunctionParameter = in.readOptionalDouble();
} else {
// Prior to introducing the loss function setting only MSE was implemented
lossFunction = LossFunction.MSE;
lossFunctionParameter = null;
}
}
public String getDependentVariable() {
@ -121,6 +152,14 @@ public class Regression implements DataFrameAnalysis {
return randomizeSeed;
}
/**
 * @return the loss function; never {@code null} — the constructor defaults it
 *         to {@link LossFunction#MSE} when unset
 */
public LossFunction getLossFunction() {
    return lossFunction;
}
/**
 * @return the optional, strictly positive loss function parameter,
 *         or {@code null} if none was configured
 */
public Double getLossFunctionParameter() {
    return lossFunctionParameter;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -135,6 +174,10 @@ public class Regression implements DataFrameAnalysis {
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeOptionalLong(randomizeSeed);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeEnum(lossFunction);
out.writeOptionalDouble(lossFunctionParameter);
}
}
@Override
@ -151,6 +194,10 @@ public class Regression implements DataFrameAnalysis {
if (version.onOrAfter(Version.V_7_6_0)) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
builder.field(LOSS_FUNCTION.getPreferredName(), lossFunction);
if (lossFunctionParameter != null) {
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
builder.endObject();
return builder;
}
@ -164,6 +211,10 @@ public class Regression implements DataFrameAnalysis {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
params.put(LOSS_FUNCTION.getPreferredName(), lossFunction.toString());
if (lossFunctionParameter != null) {
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
return params;
}
@ -232,11 +283,27 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed;
&& randomizeSeed == that.randomizeSeed
&& lossFunction == that.lossFunction
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter);
}
/**
 * The loss functions available for regression. Rendered in lower case
 * (e.g. {@code mse}) and parsed case-insensitively.
 */
public enum LossFunction {
    // NOTE(review): instances are written to the wire with StreamOutput#writeEnum,
    // which presumably encodes the ordinal — do not reorder or remove constants
    // without a BWC plan; confirm against StreamOutput's implementation.
    MSE, MSLE, HUBER;

    // Case-insensitive parse used by the XContent parser above.
    private static LossFunction fromString(String value) {
        return LossFunction.valueOf(value.toUpperCase(Locale.ROOT));
    }

    // Lower-case rendering keeps XContent output and the analytics params consistent.
    @Override
    public String toString() {
        return name().toLowerCase(Locale.ROOT);
    }
}
}

View File

@ -315,6 +315,8 @@ public final class ReservedFieldNames {
OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
Regression.NAME.getPreferredName(),
Regression.DEPENDENT_VARIABLE.getPreferredName(),
Regression.LOSS_FUNCTION.getPreferredName(),
Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(),
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
Regression.TRAINING_PERCENT.getPreferredName(),
Classification.NAME.getPreferredName(),

View File

@ -90,6 +90,12 @@
"lambda" : {
"type" : "double"
},
"loss_function" : {
"type" : "keyword"
},
"loss_function_parameter" : {
"type" : "double"
},
"max_trees" : {
"type" : "integer"
},

View File

@ -138,12 +138,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
bwcRegression.getBoostedTreeParams(),
bwcRegression.getPredictionFieldName(),
bwcRegression.getTrainingPercent(),
42L);
42L,
bwcRegression.getLossFunction(),
bwcRegression.getLossFunctionParameter());
testAnalysis = new Regression(testRegression.getDependentVariable(),
testRegression.getBoostedTreeParams(),
testRegression.getPredictionFieldName(),
testRegression.getTrainingPercent(),
42L);
42L,
testRegression.getLossFunction(),
testRegression.getLossFunctionParameter());
} else {
Classification testClassification = (Classification)testInstance.getAnalysis();
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();

View File

@ -16,8 +16,8 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.util.Map;
import java.util.Collections;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
@ -43,12 +43,18 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
}
public static Regression createRandom() {
return createRandom(BoostedTreeParamsTests.createRandom());
}
private static Regression createRandom(BoostedTreeParams boostedTreeParams) {
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter);
}
public static Regression mutateForVersion(Regression instance, Version version) {
@ -56,7 +62,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
instance.getPredictionFieldName(),
instance.getTrainingPercent(),
instance.getRandomizeSeed());
instance.getRandomizeSeed(),
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunction() : null,
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null);
}
@Override
@ -70,12 +78,16 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
bwcSerializedObject.getBoostedTreeParams(),
bwcSerializedObject.getPredictionFieldName(),
bwcSerializedObject.getTrainingPercent(),
42L);
42L,
bwcSerializedObject.getLossFunction(),
bwcSerializedObject.getLossFunctionParameter());
Regression newInstance = new Regression(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
testInstance.getTrainingPercent(),
42L);
42L,
testInstance.getLossFunction(),
testInstance.getLossFunctionParameter());
super.assertOnBWCObject(newBwc, newInstance, version);
}
@ -91,60 +103,93 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong()));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong()));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
// loss_function_parameter must be strictly positive, so zero is rejected at construction time.
public void testConstructor_GivenLossFunctionParameterIsZero() {
    ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
        () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0));

    assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
}
// loss_function_parameter must be strictly positive, so negative values are rejected at construction time.
public void testConstructor_GivenLossFunctionParameterIsNegative() {
    ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
        () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0));

    assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
}
public void testGetPredictionFieldName() {
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
assertThat(regression.getPredictionFieldName(), equalTo("result"));
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong());
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null);
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testGetTrainingPercent() {
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
assertThat(regression.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong());
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null);
assertThat(regression.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong());
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong());
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
public void testGetParams() {
assertThat(
new Regression("foo").getParams(null),
equalTo(org.elasticsearch.common.collect.Map.of(
"dependent_variable", "foo",
"prediction_field_name", "foo_prediction",
"training_percent", 100.0)));
assertThat(
new Regression("foo",
BoostedTreeParams.builder().build(),
null,
50.0,
null).getParams(null),
equalTo(org.elasticsearch.common.collect.Map.of(
"dependent_variable", "foo",
"prediction_field_name", "foo_prediction",
"training_percent", 50.0)));
// Verifies getParams includes boosted-tree settings (max_trees) alongside the core
// regression params, and that loss_function is rendered in its lower-case form.
public void testGetParams_ShouldIncludeBoostedTreeParams() {
    int maxTrees = randomIntBetween(1, 100);
    Regression regression = new Regression("foo",
        BoostedTreeParams.builder().setMaxTrees(maxTrees).build(),
        null,
        100.0,
        0L,
        Regression.LossFunction.MSE,
        null);

    Map<String, Object> params = regression.getParams(null);

    // 5 entries: dependent_variable, prediction_field_name, max_trees, training_percent, loss_function
    // (loss_function_parameter is null here and therefore omitted)
    assertThat(params.size(), equalTo(5));
    assertThat(params.get("dependent_variable"), equalTo("foo"));
    assertThat(params.get("prediction_field_name"), equalTo("foo_prediction"));
    assertThat(params.get("max_trees"), equalTo(maxTrees));
    assertThat(params.get("training_percent"), equalTo(100.0));
    assertThat(params.get("loss_function"), equalTo("mse"));
}
// Verifies getParams for a random config with empty boosted-tree params: the four core
// entries are always present, and loss_function_parameter appears only when configured.
public void testGetParams_GivenRandomWithoutBoostedTreeParams() {
    Regression regression = createRandom(BoostedTreeParams.builder().build());

    Map<String, Object> params = regression.getParams(null);

    // Core entries: dependent_variable, prediction_field_name, training_percent, loss_function,
    // plus loss_function_parameter if (and only if) it was set.
    int expectedParamsCount = 4 + (regression.getLossFunctionParameter() == null ? 0 : 1);
    assertThat(params.size(), equalTo(expectedParamsCount));
    assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable()));
    assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName()));
    assertThat(params.get("training_percent"), equalTo(regression.getTrainingPercent()));
    assertThat(params.get("loss_function"), equalTo(regression.getLossFunction().toString()));
    if (regression.getLossFunctionParameter() == null) {
        assertThat(params.containsKey("loss_function_parameter"), is(false));
    } else {
        assertThat(params.get("loss_function_parameter"), equalTo(regression.getLossFunctionParameter()));
    }
}
public void testRequiredFieldsIsNonEmpty() {

View File

@ -83,6 +83,8 @@ integTest.runner {
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero',
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative',
'ml/data_frame_analytics_crud/Test put classification given dependent_variable is not defined',
'ml/data_frame_analytics_crud/Test put classification given negative lambda',
'ml/data_frame_analytics_crud/Test put classification given negative gamma',

View File

@ -99,6 +99,8 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
BoostedTreeParams.builder().build(),
null,
100.0,
null,
null,
null))
.buildForExplain();
@ -115,6 +117,8 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
BoostedTreeParams.builder().build(),
null,
50.0,
null,
null,
null))
.buildForExplain();

View File

@ -67,6 +67,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -180,7 +182,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null));
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -307,7 +309,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null));
putAnalytics(firstJob);
String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";
@ -315,7 +317,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed));
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null));
putAnalytics(secondJob);
@ -376,7 +378,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);

View File

@ -1552,6 +1552,52 @@ setup:
}
}
---
"Test put regression given loss_function_parameter is zero":
- do:
catch: /\[loss_function_parameter\] must be a positive double/
ml.put_data_frame_analytics:
id: "regression-loss-function-param-is-zero"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"loss_function_parameter": 0.0
}
}
}
---
"Test put regression given loss_function_parameter is negative":
- do:
catch: /\[loss_function_parameter\] must be a positive double/
ml.put_data_frame_analytics:
id: "regression-loss-function-param-is-negative"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"loss_function_parameter": -0.01
}
}
}
---
"Test put regression given valid":
@ -1575,7 +1621,9 @@ setup:
"max_trees": 400,
"feature_bag_fraction": 0.3,
"training_percent": 60.3,
"randomize_seed": 42
"randomize_seed": 42,
"loss_function": "msle",
"loss_function_parameter": 2.0
}
}
}
@ -1592,7 +1640,9 @@ setup:
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3,
"randomize_seed": 42
"randomize_seed": 42,
"loss_function": "msle",
"loss_function_parameter": 2.0
}
}}
- is_true: create_time
@ -2018,7 +2068,8 @@ setup:
"dependent_variable": "foo",
"prediction_field_name": "foo_prediction",
"training_percent": 100.0,
"randomize_seed": 42
"randomize_seed": 42,
"loss_function": "mse"
}
}}
- is_true: create_time