This adds a new `randomize_seed` for regression and classification. When not explicitly set, the seed is randomly generated. One can reuse the seed in a similar job in order to ensure the same docs are picked for training. Backport of #49990
This commit is contained in:
parent
ee4a8a08dd
commit
8891f4db88
|
@ -49,6 +49,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
|
||||
private static final ConstructingObjectParser<Classification, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
|
@ -63,7 +64,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
(Double) a[5],
|
||||
(String) a[6],
|
||||
(Double) a[7],
|
||||
(Integer) a[8]));
|
||||
(Integer) a[8],
|
||||
(Long) a[9]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
|
@ -75,6 +77,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
|
@ -86,10 +89,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
private final Integer numTopClasses;
|
||||
private final Long randomizeSeed;
|
||||
|
||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
|
||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
|
@ -99,6 +103,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.numTopClasses = numTopClasses;
|
||||
this.randomizeSeed = randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -138,6 +143,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
public Long getRandomizeSeed() {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
public Integer getNumTopClasses() {
|
||||
return numTopClasses;
|
||||
}
|
||||
|
@ -167,6 +176,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (trainingPercent != null) {
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
}
|
||||
if (randomizeSeed != null) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
if (numTopClasses != null) {
|
||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||
}
|
||||
|
@ -177,7 +189,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent, numTopClasses);
|
||||
trainingPercent, randomizeSeed, numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -193,6 +205,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||
&& Objects.equals(numTopClasses, that.numTopClasses);
|
||||
}
|
||||
|
||||
|
@ -211,6 +224,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
private Integer numTopClasses;
|
||||
private Long randomizeSeed;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
@ -251,6 +265,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setRandomizeSeed(Long randomizeSeed) {
|
||||
this.randomizeSeed = randomizeSeed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumTopClasses(Integer numTopClasses) {
|
||||
this.numTopClasses = numTopClasses;
|
||||
return this;
|
||||
|
@ -258,7 +277,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
|
||||
public Classification build() {
|
||||
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent, numTopClasses);
|
||||
trainingPercent, numTopClasses, randomizeSeed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
|
||||
private static final ConstructingObjectParser<Regression, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
|
@ -61,7 +62,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
(Integer) a[4],
|
||||
(Double) a[5],
|
||||
(String) a[6],
|
||||
(Double) a[7]));
|
||||
(Double) a[7],
|
||||
(Long) a[8]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
|
@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
|
@ -82,10 +85,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
private final Double featureBagFraction;
|
||||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
private final Long randomizeSeed;
|
||||
|
||||
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent) {
|
||||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
|
@ -94,6 +98,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
this.featureBagFraction = featureBagFraction;
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -133,6 +138,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
public Long getRandomizeSeed() {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
@ -158,6 +167,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (trainingPercent != null) {
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
}
|
||||
if (randomizeSeed != null) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -165,7 +177,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -180,7 +192,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent);
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||
&& Objects.equals(randomizeSeed, that.randomizeSeed);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -197,6 +210,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
private Double featureBagFraction;
|
||||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
private Long randomizeSeed;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
@ -237,9 +251,14 @@ public class Regression implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setRandomizeSeed(Long randomizeSeed) {
|
||||
this.randomizeSeed = randomizeSeed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Regression build() {
|
||||
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
trainingPercent, randomizeSeed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1321,6 +1321,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
|
||||
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||
.setTrainingPercent(80.0)
|
||||
.setRandomizeSeed(42L)
|
||||
.build())
|
||||
.setDescription("this is a regression")
|
||||
.build();
|
||||
|
@ -1356,6 +1357,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
|
||||
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||
.setTrainingPercent(80.0)
|
||||
.setRandomizeSeed(42L)
|
||||
.setNumTopClasses(1)
|
||||
.build())
|
||||
.setDescription("this is a classification")
|
||||
|
|
|
@ -2975,7 +2975,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setFeatureBagFraction(0.4) // <6>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||
.setTrainingPercent(50.0) // <8>
|
||||
.setNumTopClasses(1) // <9>
|
||||
.setRandomizeSeed(1234L) // <9>
|
||||
.setNumTopClasses(1) // <10>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-classification
|
||||
|
||||
|
@ -2988,6 +2989,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setFeatureBagFraction(0.4) // <6>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||
.setTrainingPercent(50.0) // <8>
|
||||
.setRandomizeSeed(1234L) // <9>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-regression
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
||||
.build();
|
||||
}
|
||||
|
|
|
@ -119,7 +119,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
|||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||
<7> The name of the prediction field in the results object.
|
||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<9> The number of top classes to be reported in the results. Defaults to 2.
|
||||
<9> The seed to be used by the random generator that picks which rows are used in training.
|
||||
<10> The number of top classes to be reported in the results. Defaults to 2.
|
||||
|
||||
===== Regression
|
||||
|
||||
|
@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression]
|
|||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||
<7> The name of the prediction field in the results object.
|
||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<9> The seed to be used by the random generator that picks which rows are used in training.
|
||||
|
||||
==== Analyzed fields
|
||||
|
||||
|
|
|
@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
|
|||
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
|
||||
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
|
||||
|
||||
|
||||
[float]
|
||||
[[regression-resources-advanced]]
|
||||
|
@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
|
|||
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
|
||||
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
|
||||
|
||||
|
||||
[float]
|
||||
[[classification-resources-advanced]]
|
||||
|
|
|
@ -402,7 +402,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
|
|||
{
|
||||
"regression": {
|
||||
"dependent_variable": "G3",
|
||||
"training_percent": 70 <1>
|
||||
"training_percent": 70, <1>
|
||||
"randomize_seed": 19673948271 <2>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -411,6 +412,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
|
|||
|
||||
<1> The `training_percent` defines the percentage of the data set that will be used
|
||||
for training the model.
|
||||
<2> The `randomize_seed` is the seed used to randomly pick which data is used for training.
|
||||
|
||||
|
||||
[[ml-put-dfanalytics-example-c]]
|
||||
|
|
|
@ -68,3 +68,17 @@ be used for training. Documents that are ignored by the analysis (for example
|
|||
those that contain arrays) won’t be included in the calculation for used
|
||||
percentage. Defaults to `100`.
|
||||
end::training_percent[]
|
||||
|
||||
tag::randomize_seed[]
|
||||
`randomize_seed`::
|
||||
(Optional, long) Defines the seed to the random generator that is used to pick
|
||||
which documents will be used for training. By default it is randomly generated.
|
||||
Set it to a specific value to ensure the same documents are used for training
|
||||
assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same.
|
||||
end::randomize_seed[]
|
||||
|
||||
|
||||
tag::use-null[]
|
||||
Defines whether a new series is used as the null series when there is no value
|
||||
for the by or partition fields. The default value is `false`.
|
||||
end::use-null[]
|
||||
|
|
|
@ -225,7 +225,8 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
|
|||
builder.field(DEST.getPreferredName(), dest);
|
||||
|
||||
builder.startObject(ANALYSIS.getPreferredName());
|
||||
builder.field(analysis.getWriteableName(), analysis);
|
||||
builder.field(analysis.getWriteableName(), analysis,
|
||||
new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString())));
|
||||
builder.endObject();
|
||||
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
|
|
|
@ -49,7 +49,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
|
||||
BoostedTreeParams(@Nullable Double lambda,
|
||||
public BoostedTreeParams(@Nullable Double lambda,
|
||||
@Nullable Double gamma,
|
||||
@Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees,
|
||||
|
@ -76,7 +76,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
this.featureBagFraction = featureBagFraction;
|
||||
}
|
||||
|
||||
BoostedTreeParams() {
|
||||
public BoostedTreeParams() {
|
||||
this(null, null, null, null, null);
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Randomness;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
|
@ -35,6 +37,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
|
||||
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
|
||||
|
@ -48,12 +51,14 @@ public class Classification implements DataFrameAnalysis {
|
|||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
||||
(String) a[6],
|
||||
(Integer) a[7],
|
||||
(Double) a[8]));
|
||||
(Double) a[8],
|
||||
(Long) a[9]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -84,12 +89,14 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final String predictionFieldName;
|
||||
private final int numTopClasses;
|
||||
private final double trainingPercent;
|
||||
private final long randomizeSeed;
|
||||
|
||||
public Classification(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
@Nullable String predictionFieldName,
|
||||
@Nullable Integer numTopClasses,
|
||||
@Nullable Double trainingPercent) {
|
||||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed) {
|
||||
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
|
||||
}
|
||||
|
@ -101,10 +108,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
||||
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
||||
}
|
||||
|
||||
public Classification(String dependentVariable) {
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null, null);
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null, null, null);
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
|
@ -113,12 +121,21 @@ public class Classification implements DataFrameAnalysis {
|
|||
predictionFieldName = in.readOptionalString();
|
||||
numTopClasses = in.readOptionalVInt();
|
||||
trainingPercent = in.readDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
randomizeSeed = in.readOptionalLong();
|
||||
} else {
|
||||
randomizeSeed = Randomness.get().nextLong();
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
return dependentVariable;
|
||||
}
|
||||
|
||||
public BoostedTreeParams getBoostedTreeParams() {
|
||||
return boostedTreeParams;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
@ -131,6 +148,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Long getRandomizeSeed() {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -143,10 +165,15 @@ public class Classification implements DataFrameAnalysis {
|
|||
out.writeOptionalString(predictionFieldName);
|
||||
out.writeOptionalVInt(numTopClasses);
|
||||
out.writeDouble(trainingPercent);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
out.writeOptionalLong(randomizeSeed);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
Version version = Version.fromString(params.param("version", Version.CURRENT.toString()));
|
||||
|
||||
builder.startObject();
|
||||
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
boostedTreeParams.toXContent(builder, params);
|
||||
|
@ -155,6 +182,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
if (version.onOrAfter(Version.V_7_6_0)) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -240,11 +270,12 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||
&& trainingPercent == that.trainingPercent;
|
||||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == that.randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Randomness;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
|
@ -32,6 +34,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
||||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
|
||||
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
|
||||
|
@ -44,11 +47,13 @@ public class Regression implements DataFrameAnalysis {
|
|||
(String) a[0],
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
||||
(String) a[6],
|
||||
(Double) a[7]));
|
||||
(Double) a[7],
|
||||
(Long) a[8]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -60,11 +65,13 @@ public class Regression implements DataFrameAnalysis {
|
|||
private final BoostedTreeParams boostedTreeParams;
|
||||
private final String predictionFieldName;
|
||||
private final double trainingPercent;
|
||||
private final long randomizeSeed;
|
||||
|
||||
public Regression(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
@Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent) {
|
||||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed) {
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
|
@ -72,10 +79,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
||||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null);
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
|
@ -83,12 +91,21 @@ public class Regression implements DataFrameAnalysis {
|
|||
boostedTreeParams = new BoostedTreeParams(in);
|
||||
predictionFieldName = in.readOptionalString();
|
||||
trainingPercent = in.readDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
randomizeSeed = in.readOptionalLong();
|
||||
} else {
|
||||
randomizeSeed = Randomness.get().nextLong();
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
return dependentVariable;
|
||||
}
|
||||
|
||||
public BoostedTreeParams getBoostedTreeParams() {
|
||||
return boostedTreeParams;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
@ -97,6 +114,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Long getRandomizeSeed() {
|
||||
return randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -108,10 +130,15 @@ public class Regression implements DataFrameAnalysis {
|
|||
boostedTreeParams.writeTo(out);
|
||||
out.writeOptionalString(predictionFieldName);
|
||||
out.writeDouble(trainingPercent);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
out.writeOptionalLong(randomizeSeed);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
Version version = Version.fromString(params.param("version", Version.CURRENT.toString()));
|
||||
|
||||
builder.startObject();
|
||||
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
boostedTreeParams.toXContent(builder, params);
|
||||
|
@ -119,6 +146,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
if (version.onOrAfter(Version.V_7_6_0)) {
|
||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -177,11 +207,12 @@ public class Regression implements DataFrameAnalysis {
|
|||
return Objects.equals(dependentVariable, that.dependentVariable)
|
||||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& trainingPercent == that.trainingPercent;
|
||||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == randomizeSeed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent);
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator;
|
|||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
|
@ -20,17 +21,20 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
|||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -42,10 +46,13 @@ import java.util.HashMap;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasEntry;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.notNullValue;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
|
||||
public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<DataFrameAnalyticsConfig> {
|
||||
|
@ -339,6 +346,44 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
|
|||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenAnalysisWithRandomizeSeedAndVersionIsCurrent() throws IOException {
|
||||
Regression regression = new Regression("foo");
|
||||
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
|
||||
.setVersion(Version.CURRENT)
|
||||
.setId("test_config")
|
||||
.setSource(new DataFrameAnalyticsSource(new String[] {"source_index"}, null, null))
|
||||
.setDest(new DataFrameAnalyticsDest("dest_index", null))
|
||||
.setAnalysis(regression)
|
||||
.build();
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenAnalysisWithRandomizeSeedAndVersionIsBeforeItWasIntroduced() throws IOException {
|
||||
Regression regression = new Regression("foo");
|
||||
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
|
||||
.setVersion(Version.V_7_5_0)
|
||||
.setId("test_config")
|
||||
.setSource(new DataFrameAnalyticsSource(new String[] {"source_index"}, null, null))
|
||||
.setDest(new DataFrameAnalyticsDest("dest_index", null))
|
||||
.setAnalysis(regression)
|
||||
.build();
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, not(containsString("randomize_seed")));
|
||||
}
|
||||
}
|
||||
|
||||
private static void assertTooSmall(ElasticsearchStatusException e) {
|
||||
assertThat(e.getMessage(), startsWith("model_memory_limit must be at least 1kb."));
|
||||
}
|
||||
|
|
|
@ -6,8 +6,13 @@
|
|||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
|
@ -20,10 +25,12 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasEntry;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.notNullValue;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
|
@ -46,7 +53,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent,
|
||||
randomizeSeed);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -56,71 +65,71 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong()));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong()));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong()));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong()));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong());
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetNumTopClasses() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong());
|
||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||
|
||||
// Boundary condition: num_top_classes == 0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong());
|
||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||
|
||||
// Boundary condition: num_top_classes == 1000
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong());
|
||||
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
||||
|
||||
// num_top_classes == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong());
|
||||
assertThat(classification.getNumTopClasses(), equalTo(2));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
|
||||
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong());
|
||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong());
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong());
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
|
@ -155,4 +164,48 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
public void testFieldCardinalityLimitsIsNonNull() {
|
||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
||||
Classification classification = createRandom();
|
||||
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0")));
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, not(containsString("randomize_seed")));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException {
|
||||
Classification classification = createRandom();
|
||||
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString())));
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionIsNull() throws IOException {
|
||||
Classification classification = createRandom();
|
||||
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null)));
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenEmptyParams() throws IOException {
|
||||
Classification classification = createRandom();
|
||||
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
classification.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,17 +6,25 @@
|
|||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasEntry;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.notNullValue;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
@ -38,7 +46,8 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -48,40 +57,40 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong()));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong()));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong());
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
|
||||
assertThat(regression.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong());
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong());
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong());
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
|
@ -101,4 +110,48 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
String randomId = randomAlphaOfLength(10);
|
||||
assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
||||
Regression regression = createRandom();
|
||||
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0")));
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, not(containsString("randomize_seed")));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException {
|
||||
Regression regression = createRandom();
|
||||
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString())));
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionIsNull() throws IOException {
|
||||
Regression regression = createRandom();
|
||||
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null)));
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToXContent_GivenEmptyParams() throws IOException {
|
||||
Regression regression = createRandom();
|
||||
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
|
||||
|
||||
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
|
||||
regression.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String json = Strings.toString(builder);
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
|||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
|
@ -31,6 +32,7 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
|
@ -158,7 +160,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0));
|
||||
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -268,6 +270,44 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
}
|
||||
|
||||
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
|
||||
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
|
||||
String dependentVariable = KEYWORD_FIELD;
|
||||
indexData(sourceIndex, 10, 0, dependentVariable);
|
||||
|
||||
String firstJobId = "classification_two_jobs_with_same_randomize_seed_1";
|
||||
String firstJobDestIndex = firstJobId + "_dest";
|
||||
|
||||
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
|
||||
registerAnalytics(firstJob);
|
||||
putAnalytics(firstJob);
|
||||
|
||||
String secondJobId = "classification_two_jobs_with_same_randomize_seed_2";
|
||||
String secondJobDestIndex = secondJobId + "_dest";
|
||||
|
||||
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
|
||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed));
|
||||
|
||||
registerAnalytics(secondJob);
|
||||
putAnalytics(secondJob);
|
||||
|
||||
// Let's run both jobs in parallel and wait until they are finished
|
||||
startAnalytics(firstJobId);
|
||||
startAnalytics(secondJobId);
|
||||
waitUntilAnalyticsIsStopped(firstJobId);
|
||||
waitUntilAnalyticsIsStopped(secondJobId);
|
||||
|
||||
// Now we compare they both used the same training rows
|
||||
Set<String> firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex);
|
||||
Set<String> secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex);
|
||||
|
||||
assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds));
|
||||
}
|
||||
|
||||
private void initialize(String jobId) {
|
||||
this.jobId = jobId;
|
||||
this.sourceIndex = jobId + "_source_index";
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.common.Strings;
|
|||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
|
@ -45,7 +46,10 @@ import org.hamcrest.Matchers;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -252,4 +256,22 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
|||
.map(hit -> (String) hit.getSourceAsMap().get("message"))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
protected static Set<String> getTrainingRowsIds(String index) {
|
||||
Set<String> trainingRowsIds = new HashSet<>();
|
||||
SearchResponse hits = client().prepareSearch(index).get();
|
||||
for (SearchHit hit : hits.getHits()) {
|
||||
Map<String, Object> sourceAsMap = hit.getSourceAsMap();
|
||||
assertThat(sourceAsMap.containsKey("ml"), is(true));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) sourceAsMap.get("ml");
|
||||
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
if (Boolean.TRUE.equals(resultsObject.get("is_training"))) {
|
||||
trainingRowsIds.add(hit.getId());
|
||||
}
|
||||
}
|
||||
assertThat(trainingRowsIds.isEmpty(), is(false));
|
||||
return trainingRowsIds;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
|||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
|
@ -25,6 +26,7 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -139,7 +141,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -235,6 +237,43 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertInferenceModelPersisted(jobId);
|
||||
}
|
||||
|
||||
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
|
||||
String sourceIndex = "regression_two_jobs_with_same_randomize_seed_source";
|
||||
indexData(sourceIndex, 10, 0);
|
||||
|
||||
String firstJobId = "regression_two_jobs_with_same_randomize_seed_1";
|
||||
String firstJobDestIndex = firstJobId + "_dest";
|
||||
|
||||
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
|
||||
registerAnalytics(firstJob);
|
||||
putAnalytics(firstJob);
|
||||
|
||||
String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";
|
||||
String secondJobDestIndex = secondJobId + "_dest";
|
||||
|
||||
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
|
||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed));
|
||||
|
||||
registerAnalytics(secondJob);
|
||||
putAnalytics(secondJob);
|
||||
|
||||
// Let's run both jobs in parallel and wait until they are finished
|
||||
startAnalytics(firstJobId);
|
||||
startAnalytics(secondJobId);
|
||||
waitUntilAnalyticsIsStopped(firstJobId);
|
||||
waitUntilAnalyticsIsStopped(secondJobId);
|
||||
|
||||
// Now we compare they both used the same training rows
|
||||
Set<String> firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex);
|
||||
Set<String> secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex);
|
||||
|
||||
assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds));
|
||||
}
|
||||
|
||||
private void initialize(String jobId) {
|
||||
this.jobId = jobId;
|
||||
this.sourceIndex = jobId + "_source_index";
|
||||
|
|
|
@ -111,7 +111,7 @@ public class TransportPutDataFrameAnalyticsAction
|
|||
protected void masterOperation(PutDataFrameAnalyticsAction.Request request, ClusterState state,
|
||||
ActionListener<PutDataFrameAnalyticsAction.Response> listener) {
|
||||
validateConfig(request.getConfig());
|
||||
DataFrameAnalyticsConfig memoryCappedConfig =
|
||||
DataFrameAnalyticsConfig preparedForPutConfig =
|
||||
new DataFrameAnalyticsConfig.Builder(request.getConfig(), maxModelMemoryLimit)
|
||||
.setCreateTime(Instant.now())
|
||||
.setVersion(Version.CURRENT)
|
||||
|
@ -120,11 +120,11 @@ public class TransportPutDataFrameAnalyticsAction
|
|||
if (licenseState.isAuthAllowed()) {
|
||||
final String username = securityContext.getUser().principal();
|
||||
RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
|
||||
.indices(memoryCappedConfig.getSource().getIndex())
|
||||
.indices(preparedForPutConfig.getSource().getIndex())
|
||||
.privileges("read")
|
||||
.build();
|
||||
RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
|
||||
.indices(memoryCappedConfig.getDest().getIndex())
|
||||
.indices(preparedForPutConfig.getDest().getIndex())
|
||||
.privileges("read", "index", "create_index")
|
||||
.build();
|
||||
|
||||
|
@ -135,16 +135,16 @@ public class TransportPutDataFrameAnalyticsAction
|
|||
privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges);
|
||||
|
||||
ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
|
||||
r -> handlePrivsResponse(username, memoryCappedConfig, r, listener),
|
||||
r -> handlePrivsResponse(username, preparedForPutConfig, r, listener),
|
||||
listener::onFailure);
|
||||
|
||||
client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
|
||||
} else {
|
||||
updateDocMappingAndPutConfig(
|
||||
memoryCappedConfig,
|
||||
preparedForPutConfig,
|
||||
threadPool.getThreadContext().getHeaders(),
|
||||
ActionListener.wrap(
|
||||
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(memoryCappedConfig)),
|
||||
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(preparedForPutConfig)),
|
||||
listener::onFailure
|
||||
));
|
||||
}
|
||||
|
|
|
@ -24,12 +24,12 @@ public class CustomProcessorFactory {
|
|||
if (analysis instanceof Regression) {
|
||||
Regression regression = (Regression) analysis;
|
||||
return new DatasetSplittingCustomProcessor(
|
||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent());
|
||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
|
||||
}
|
||||
if (analysis instanceof Classification) {
|
||||
Classification classification = (Classification) analysis;
|
||||
return new DatasetSplittingCustomProcessor(
|
||||
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent());
|
||||
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
|
||||
}
|
||||
return row -> {};
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.common.Randomness;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -23,12 +22,13 @@ class DatasetSplittingCustomProcessor implements CustomProcessor {
|
|||
|
||||
private final int dependentVariableIndex;
|
||||
private final double trainingPercent;
|
||||
private final Random random = Randomness.get();
|
||||
private final Random random;
|
||||
private boolean isFirstRow = true;
|
||||
|
||||
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent) {
|
||||
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.random = new Random(randomizeSeed);
|
||||
}
|
||||
|
||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
|
|
|
@ -24,6 +24,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|||
private List<String> fields;
|
||||
private int dependentVariableIndex;
|
||||
private String dependentVariable;
|
||||
private long randomizeSeed;
|
||||
|
||||
@Before
|
||||
public void setUpTests() {
|
||||
|
@ -34,10 +35,11 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|||
}
|
||||
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
|
||||
dependentVariable = fields.get(dependentVariableIndex);
|
||||
randomizeSeed = randomLong();
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0);
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -55,7 +57,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0);
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -75,7 +77,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
||||
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
|
||||
double trainingFraction = trainingPercent / 100;
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent);
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed);
|
||||
|
||||
int runCount = 20;
|
||||
int rowsCount = 1000;
|
||||
|
@ -121,7 +123,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0);
|
||||
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed);
|
||||
|
||||
// We have some non-training rows and then a training row to check
|
||||
// we maintain the first training row and not just the first row
|
||||
|
|
|
@ -1456,7 +1456,8 @@ setup:
|
|||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3
|
||||
"training_percent": 60.3,
|
||||
"randomize_seed": 42
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1472,7 +1473,8 @@ setup:
|
|||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 60.3
|
||||
"training_percent": 60.3,
|
||||
"randomize_seed": 42
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
|
@ -1796,7 +1798,8 @@ setup:
|
|||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3
|
||||
"training_percent": 60.3,
|
||||
"randomize_seed": 24
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1813,6 +1816,7 @@ setup:
|
|||
"feature_bag_fraction": 0.3,
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 60.3,
|
||||
"randomize_seed": 24,
|
||||
"num_top_classes": 2
|
||||
}
|
||||
}}
|
||||
|
@ -1836,7 +1840,8 @@ setup:
|
|||
},
|
||||
"analysis": {
|
||||
"regression": {
|
||||
"dependent_variable": "foo"
|
||||
"dependent_variable": "foo",
|
||||
"randomize_seed": 42
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1848,7 +1853,8 @@ setup:
|
|||
"regression":{
|
||||
"dependent_variable": "foo",
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 100.0
|
||||
"training_percent": 100.0,
|
||||
"randomize_seed": 42
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
|
|
Loading…
Reference in New Issue