mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-24 17:09:48 +00:00
Adds parameters `loss_function` and `loss_function_parameter` to regression. Backport of #56118
This commit is contained in:
parent
c38388c506
commit
75dadb7a6d
@ -22,12 +22,16 @@ import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class Regression implements DataFrameAnalysis {
|
||||
|
||||
public static Regression fromXContent(XContentParser parser) {
|
||||
@ -50,6 +54,8 @@ public class Regression implements DataFrameAnalysis {
|
||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
|
||||
static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
|
||||
|
||||
private static final ConstructingObjectParser<Regression, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
@ -65,7 +71,10 @@ public class Regression implements DataFrameAnalysis {
|
||||
(Integer) a[6],
|
||||
(String) a[7],
|
||||
(Double) a[8],
|
||||
(Long) a[9]));
|
||||
(Long) a[9],
|
||||
(LossFunction) a[10],
|
||||
(Double) a[11]
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
@ -78,6 +87,13 @@ public class Regression implements DataFrameAnalysis {
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
PARSER.declareField(optionalConstructorArg(), p -> {
|
||||
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||
return LossFunction.fromString(p.text());
|
||||
}
|
||||
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
|
||||
}, LOSS_FUNCTION, ObjectParser.ValueType.STRING);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
@ -90,11 +106,14 @@ public class Regression implements DataFrameAnalysis {
|
||||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
private final Long randomizeSeed;
|
||||
private final LossFunction lossFunction;
|
||||
private final Double lossFunctionParameter;
|
||||
|
||||
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
||||
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
|
||||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
|
||||
@Nullable Double lossFunctionParameter) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
@ -105,6 +124,8 @@ public class Regression implements DataFrameAnalysis {
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed;
|
||||
this.lossFunction = lossFunction;
|
||||
this.lossFunctionParameter = lossFunctionParameter;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -152,6 +173,14 @@ public class Regression implements DataFrameAnalysis {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
public LossFunction getLossFunction() {
|
||||
return lossFunction;
|
||||
}
|
||||
|
||||
public Double getLossFunctionParameter() {
|
||||
return lossFunctionParameter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
@ -183,6 +212,12 @@ public class Regression implements DataFrameAnalysis {
|
||||
if (randomizeSeed != null) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
if (lossFunction != null) {
|
||||
builder.field(LOSS_FUNCTION.getPreferredName(), lossFunction);
|
||||
}
|
||||
if (lossFunctionParameter != null) {
|
||||
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
@ -190,7 +225,7 @@ public class Regression implements DataFrameAnalysis {
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||
predictionFieldName, trainingPercent, randomizeSeed);
|
||||
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -207,7 +242,9 @@ public class Regression implements DataFrameAnalysis {
|
||||
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||
&& Objects.equals(randomizeSeed, that.randomizeSeed);
|
||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||
&& Objects.equals(lossFunction, that.lossFunction)
|
||||
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -226,6 +263,8 @@ public class Regression implements DataFrameAnalysis {
|
||||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
private Long randomizeSeed;
|
||||
private LossFunction lossFunction;
|
||||
private Double lossFunctionParameter;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
@ -276,9 +315,32 @@ public class Regression implements DataFrameAnalysis {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setLossFunction(LossFunction lossFunction) {
|
||||
this.lossFunction = lossFunction;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setLossFunctionParameter(Double lossFunctionParameter) {
|
||||
this.lossFunctionParameter = lossFunctionParameter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Regression build() {
|
||||
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
|
||||
}
|
||||
}
|
||||
|
||||
public enum LossFunction {
|
||||
MSE, MSLE, HUBER;
|
||||
|
||||
private static LossFunction fromString(String value) {
|
||||
return LossFunction.valueOf(value.toUpperCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return name().toLowerCase(Locale.ROOT);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1356,6 +1356,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
.setMaxTrees(10)
|
||||
.setFeatureBagFraction(0.5)
|
||||
.setNumTopFeatureImportanceValues(3)
|
||||
.setLossFunction(org.elasticsearch.client.ml.dataframe.Regression.LossFunction.MSLE)
|
||||
.setLossFunctionParameter(1.0)
|
||||
.build())
|
||||
.setDescription("this is a regression")
|
||||
.build();
|
||||
|
@ -151,6 +151,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
@ -3007,6 +3008,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
.setPredictionFieldName("my_prediction_field_name") // <8>
|
||||
.setTrainingPercent(50.0) // <9>
|
||||
.setRandomizeSeed(1234L) // <10>
|
||||
.setLossFunction(Regression.LossFunction.MSE) // <11>
|
||||
.setLossFunctionParameter(1.0) // <12>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-regression
|
||||
|
||||
|
@ -35,6 +35,8 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.setLossFunction(randomBoolean() ? null : randomFrom(Regression.LossFunction.values()))
|
||||
.setLossFunctionParameter(randomBoolean() ? null : randomDoubleBetween(1.0, Double.MAX_VALUE, true))
|
||||
.build();
|
||||
}
|
||||
|
||||
|
@ -143,6 +143,8 @@ include-tagged::{doc-tests-file}[{api}-regression]
|
||||
<8> The name of the prediction field in the results object.
|
||||
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||
<11> The loss function used for regression. Defaults to `mse`.
|
||||
<12> An optional parameter to the loss function.
|
||||
|
||||
==== Analyzed fields
|
||||
|
||||
|
@ -225,6 +225,15 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
|
||||
(Optional, double)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
|
||||
|
||||
`loss_function`::::
|
||||
(Optional, string)
|
||||
The loss function used during regression. Available options are `mse` (mean squared error),
|
||||
`msle` (mean squared logarithmic error), `huber` (Pseudo-Huber loss). Defaults to `mse`.
|
||||
|
||||
`loss_function_parameter`::::
|
||||
(Optional, double)
|
||||
A strictly positive number that is used as a parameter to the `loss_function`.
|
||||
|
||||
`max_trees`::::
|
||||
(Optional, integer)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=max-trees]
|
||||
|
@ -12,6 +12,7 @@ import org.elasticsearch.common.Randomness;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
@ -21,6 +22,7 @@ import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
@ -36,6 +38,8 @@ public class Regression implements DataFrameAnalysis {
|
||||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
|
||||
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
|
||||
|
||||
private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
|
||||
|
||||
@ -51,12 +55,21 @@ public class Regression implements DataFrameAnalysis {
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
||||
(String) a[7],
|
||||
(Double) a[8],
|
||||
(Long) a[9]));
|
||||
(Long) a[9],
|
||||
(LossFunction) a[10],
|
||||
(Double) a[11]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
parser.declareField(optionalConstructorArg(), p -> {
|
||||
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||
return LossFunction.fromString(p.text());
|
||||
}
|
||||
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
|
||||
}, LOSS_FUNCTION, ObjectParser.ValueType.STRING);
|
||||
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
|
||||
return parser;
|
||||
}
|
||||
|
||||
@ -69,12 +82,16 @@ public class Regression implements DataFrameAnalysis {
|
||||
private final String predictionFieldName;
|
||||
private final double trainingPercent;
|
||||
private final long randomizeSeed;
|
||||
private final LossFunction lossFunction;
|
||||
private final Double lossFunctionParameter;
|
||||
|
||||
public Regression(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
@Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed) {
|
||||
@Nullable Long randomizeSeed,
|
||||
@Nullable LossFunction lossFunction,
|
||||
@Nullable Double lossFunctionParameter) {
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
@ -83,10 +100,16 @@ public class Regression implements DataFrameAnalysis {
|
||||
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
||||
// Prior to introducing the loss function setting only MSE was implemented
|
||||
this.lossFunction = lossFunction == null ? LossFunction.MSE : lossFunction;
|
||||
if (lossFunctionParameter != null && lossFunctionParameter <= 0.0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
|
||||
}
|
||||
this.lossFunctionParameter = lossFunctionParameter;
|
||||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
@ -99,6 +122,14 @@ public class Regression implements DataFrameAnalysis {
|
||||
} else {
|
||||
randomizeSeed = Randomness.get().nextLong();
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
lossFunction = in.readEnum(LossFunction.class);
|
||||
lossFunctionParameter = in.readOptionalDouble();
|
||||
} else {
|
||||
// Prior to introducing the loss function setting only MSE was implemented
|
||||
lossFunction = LossFunction.MSE;
|
||||
lossFunctionParameter = null;
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
@ -121,6 +152,14 @@ public class Regression implements DataFrameAnalysis {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
public LossFunction getLossFunction() {
|
||||
return lossFunction;
|
||||
}
|
||||
|
||||
public Double getLossFunctionParameter() {
|
||||
return lossFunctionParameter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
@ -135,6 +174,10 @@ public class Regression implements DataFrameAnalysis {
|
||||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
out.writeOptionalLong(randomizeSeed);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
|
||||
out.writeEnum(lossFunction);
|
||||
out.writeOptionalDouble(lossFunctionParameter);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -151,6 +194,10 @@ public class Regression implements DataFrameAnalysis {
|
||||
if (version.onOrAfter(Version.V_7_6_0)) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
builder.field(LOSS_FUNCTION.getPreferredName(), lossFunction);
|
||||
if (lossFunctionParameter != null) {
|
||||
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
@ -164,6 +211,10 @@ public class Regression implements DataFrameAnalysis {
|
||||
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
params.put(LOSS_FUNCTION.getPreferredName(), lossFunction.toString());
|
||||
if (lossFunctionParameter != null) {
|
||||
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
@ -232,11 +283,27 @@ public class Regression implements DataFrameAnalysis {
|
||||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == that.randomizeSeed;
|
||||
&& randomizeSeed == that.randomizeSeed
|
||||
&& lossFunction == that.lossFunction
|
||||
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
|
||||
lossFunctionParameter);
|
||||
}
|
||||
|
||||
public enum LossFunction {
|
||||
MSE, MSLE, HUBER;
|
||||
|
||||
private static LossFunction fromString(String value) {
|
||||
return LossFunction.valueOf(value.toUpperCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return name().toLowerCase(Locale.ROOT);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -315,6 +315,8 @@ public final class ReservedFieldNames {
|
||||
OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(),
|
||||
Regression.NAME.getPreferredName(),
|
||||
Regression.DEPENDENT_VARIABLE.getPreferredName(),
|
||||
Regression.LOSS_FUNCTION.getPreferredName(),
|
||||
Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(),
|
||||
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Regression.TRAINING_PERCENT.getPreferredName(),
|
||||
Classification.NAME.getPreferredName(),
|
||||
|
@ -90,6 +90,12 @@
|
||||
"lambda" : {
|
||||
"type" : "double"
|
||||
},
|
||||
"loss_function" : {
|
||||
"type" : "keyword"
|
||||
},
|
||||
"loss_function_parameter" : {
|
||||
"type" : "double"
|
||||
},
|
||||
"max_trees" : {
|
||||
"type" : "integer"
|
||||
},
|
||||
|
@ -138,12 +138,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
||||
bwcRegression.getBoostedTreeParams(),
|
||||
bwcRegression.getPredictionFieldName(),
|
||||
bwcRegression.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
bwcRegression.getLossFunction(),
|
||||
bwcRegression.getLossFunctionParameter());
|
||||
testAnalysis = new Regression(testRegression.getDependentVariable(),
|
||||
testRegression.getBoostedTreeParams(),
|
||||
testRegression.getPredictionFieldName(),
|
||||
testRegression.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
testRegression.getLossFunction(),
|
||||
testRegression.getLossFunctionParameter());
|
||||
} else {
|
||||
Classification testClassification = (Classification)testInstance.getAnalysis();
|
||||
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
|
||||
|
@ -16,8 +16,8 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
@ -43,12 +43,18 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
||||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
return createRandom(BoostedTreeParamsTests.createRandom());
|
||||
}
|
||||
|
||||
private static Regression createRandom(BoostedTreeParams boostedTreeParams) {
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
|
||||
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
|
||||
lossFunctionParameter);
|
||||
}
|
||||
|
||||
public static Regression mutateForVersion(Regression instance, Version version) {
|
||||
@ -56,7 +62,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
||||
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
|
||||
instance.getPredictionFieldName(),
|
||||
instance.getTrainingPercent(),
|
||||
instance.getRandomizeSeed());
|
||||
instance.getRandomizeSeed(),
|
||||
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunction() : null,
|
||||
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -70,12 +78,16 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
||||
bwcSerializedObject.getBoostedTreeParams(),
|
||||
bwcSerializedObject.getPredictionFieldName(),
|
||||
bwcSerializedObject.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
bwcSerializedObject.getLossFunction(),
|
||||
bwcSerializedObject.getLossFunctionParameter());
|
||||
Regression newInstance = new Regression(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
testInstance.getTrainingPercent(),
|
||||
42L);
|
||||
42L,
|
||||
testInstance.getLossFunction(),
|
||||
testInstance.getLossFunctionParameter());
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
@ -91,60 +103,93 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong()));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong()));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenLossFunctionParameterIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenLossFunctionParameterIsNegative() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong());
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong());
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong());
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong());
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testGetParams() {
|
||||
assertThat(
|
||||
new Regression("foo").getParams(null),
|
||||
equalTo(org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "foo",
|
||||
"prediction_field_name", "foo_prediction",
|
||||
"training_percent", 100.0)));
|
||||
assertThat(
|
||||
new Regression("foo",
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
50.0,
|
||||
null).getParams(null),
|
||||
equalTo(org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "foo",
|
||||
"prediction_field_name", "foo_prediction",
|
||||
"training_percent", 50.0)));
|
||||
public void testGetParams_ShouldIncludeBoostedTreeParams() {
|
||||
int maxTrees = randomIntBetween(1, 100);
|
||||
Regression regression = new Regression("foo",
|
||||
BoostedTreeParams.builder().setMaxTrees(maxTrees).build(),
|
||||
null,
|
||||
100.0,
|
||||
0L,
|
||||
Regression.LossFunction.MSE,
|
||||
null);
|
||||
|
||||
Map<String, Object> params = regression.getParams(null);
|
||||
|
||||
assertThat(params.size(), equalTo(5));
|
||||
assertThat(params.get("dependent_variable"), equalTo("foo"));
|
||||
assertThat(params.get("prediction_field_name"), equalTo("foo_prediction"));
|
||||
assertThat(params.get("max_trees"), equalTo(maxTrees));
|
||||
assertThat(params.get("training_percent"), equalTo(100.0));
|
||||
assertThat(params.get("loss_function"), equalTo("mse"));
|
||||
}
|
||||
|
||||
public void testGetParams_GivenRandomWithoutBoostedTreeParams() {
|
||||
Regression regression = createRandom(BoostedTreeParams.builder().build());
|
||||
|
||||
Map<String, Object> params = regression.getParams(null);
|
||||
|
||||
int expectedParamsCount = 4 + (regression.getLossFunctionParameter() == null ? 0 : 1);
|
||||
assertThat(params.size(), equalTo(expectedParamsCount));
|
||||
assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable()));
|
||||
assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName()));
|
||||
assertThat(params.get("training_percent"), equalTo(regression.getTrainingPercent()));
|
||||
assertThat(params.get("loss_function"), equalTo(regression.getLossFunction().toString()));
|
||||
if (regression.getLossFunctionParameter() == null) {
|
||||
assertThat(params.containsKey("loss_function_parameter"), is(false));
|
||||
} else {
|
||||
assertThat(params.get("loss_function_parameter"), equalTo(regression.getLossFunctionParameter()));
|
||||
}
|
||||
}
|
||||
|
||||
public void testRequiredFieldsIsNonEmpty() {
|
||||
|
@ -83,6 +83,8 @@ integTest.runner {
|
||||
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
|
||||
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero',
|
||||
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative',
|
||||
'ml/data_frame_analytics_crud/Test put classification given dependent_variable is not defined',
|
||||
'ml/data_frame_analytics_crud/Test put classification given negative lambda',
|
||||
'ml/data_frame_analytics_crud/Test put classification given negative gamma',
|
||||
|
@ -99,6 +99,8 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
100.0,
|
||||
null,
|
||||
null,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
@ -115,6 +117,8 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
50.0,
|
||||
null,
|
||||
null,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
|
@ -67,6 +67,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null)
|
||||
);
|
||||
putAnalytics(config);
|
||||
@ -180,7 +182,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null));
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
@ -307,7 +309,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
.build();
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null));
|
||||
putAnalytics(firstJob);
|
||||
|
||||
String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";
|
||||
@ -315,7 +317,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
|
||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null));
|
||||
|
||||
putAnalytics(secondJob);
|
||||
|
||||
@ -376,7 +378,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
|
||||
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null));
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
|
@ -1552,6 +1552,52 @@ setup:
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put regression given loss_function_parameter is zero":
|
||||
|
||||
- do:
|
||||
catch: /\[loss_function_parameter\] must be a positive double/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "regression-loss-function-param-is-zero"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"regression": {
|
||||
"dependent_variable": "foo",
|
||||
"loss_function_parameter": 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put regression given loss_function_parameter is negative":
|
||||
|
||||
- do:
|
||||
catch: /\[loss_function_parameter\] must be a positive double/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "regression-loss-function-param-is-negative"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"regression": {
|
||||
"dependent_variable": "foo",
|
||||
"loss_function_parameter": -0.01
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put regression given valid":
|
||||
|
||||
@ -1575,7 +1621,9 @@ setup:
|
||||
"max_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3,
|
||||
"randomize_seed": 42
|
||||
"randomize_seed": 42,
|
||||
"loss_function": "msle",
|
||||
"loss_function_parameter": 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1592,7 +1640,9 @@ setup:
|
||||
"feature_bag_fraction": 0.3,
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 60.3,
|
||||
"randomize_seed": 42
|
||||
"randomize_seed": 42,
|
||||
"loss_function": "msle",
|
||||
"loss_function_parameter": 2.0
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
@ -2018,7 +2068,8 @@ setup:
|
||||
"dependent_variable": "foo",
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 100.0,
|
||||
"randomize_seed": 42
|
||||
"randomize_seed": 42,
|
||||
"loss_function": "mse"
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
|
Loading…
x
Reference in New Issue
Block a user