[7.x][ML] Introduce randomize_seed setting for regression and classification (#49990) (#50023)

This adds a new `randomize_seed` for regression and classification.
When not explicitly set, the seed is randomly generated. One can
reuse the seed in a similar job in order to ensure the same docs
are picked for training.

Backport of #49990
This commit is contained in:
Dimitris Athanasiou 2019-12-10 15:29:19 +02:00 committed by GitHub
parent ee4a8a08dd
commit 8891f4db88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 465 additions and 77 deletions

View File

@ -49,6 +49,7 @@ public class Classification implements DataFrameAnalysis {
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
private static final ConstructingObjectParser<Classification, Void> PARSER =
new ConstructingObjectParser<>(
@ -63,7 +64,8 @@ public class Classification implements DataFrameAnalysis {
(Double) a[5],
(String) a[6],
(Double) a[7],
(Integer) a[8]));
(Integer) a[8],
(Long) a[9]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@ -75,6 +77,7 @@ public class Classification implements DataFrameAnalysis {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
}
private final String dependentVariable;
@ -86,10 +89,11 @@ public class Classification implements DataFrameAnalysis {
private final String predictionFieldName;
private final Double trainingPercent;
private final Integer numTopClasses;
private final Long randomizeSeed;
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -99,6 +103,7 @@ public class Classification implements DataFrameAnalysis {
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.numTopClasses = numTopClasses;
this.randomizeSeed = randomizeSeed;
}
@Override
@ -138,6 +143,10 @@ public class Classification implements DataFrameAnalysis {
return trainingPercent;
}
public Long getRandomizeSeed() {
return randomizeSeed;
}
public Integer getNumTopClasses() {
return numTopClasses;
}
@ -167,6 +176,9 @@ public class Classification implements DataFrameAnalysis {
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
@ -177,7 +189,7 @@ public class Classification implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent, numTopClasses);
trainingPercent, randomizeSeed, numTopClasses);
}
@Override
@ -193,6 +205,7 @@ public class Classification implements DataFrameAnalysis {
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(numTopClasses, that.numTopClasses);
}
@ -211,6 +224,7 @@ public class Classification implements DataFrameAnalysis {
private String predictionFieldName;
private Double trainingPercent;
private Integer numTopClasses;
private Long randomizeSeed;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -251,6 +265,11 @@ public class Classification implements DataFrameAnalysis {
return this;
}
public Builder setRandomizeSeed(Long randomizeSeed) {
this.randomizeSeed = randomizeSeed;
return this;
}
public Builder setNumTopClasses(Integer numTopClasses) {
this.numTopClasses = numTopClasses;
return this;
@ -258,7 +277,7 @@ public class Classification implements DataFrameAnalysis {
public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent, numTopClasses);
trainingPercent, numTopClasses, randomizeSeed);
}
}
}

View File

@ -48,6 +48,7 @@ public class Regression implements DataFrameAnalysis {
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
private static final ConstructingObjectParser<Regression, Void> PARSER =
new ConstructingObjectParser<>(
@ -61,7 +62,8 @@ public class Regression implements DataFrameAnalysis {
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));
(Double) a[7],
(Long) a[8]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
}
private final String dependentVariable;
@ -82,10 +85,11 @@ public class Regression implements DataFrameAnalysis {
private final Double featureBagFraction;
private final String predictionFieldName;
private final Double trainingPercent;
private final Long randomizeSeed;
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -94,6 +98,7 @@ public class Regression implements DataFrameAnalysis {
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.randomizeSeed = randomizeSeed;
}
@Override
@ -133,6 +138,10 @@ public class Regression implements DataFrameAnalysis {
return trainingPercent;
}
public Long getRandomizeSeed() {
return randomizeSeed;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -158,6 +167,9 @@ public class Regression implements DataFrameAnalysis {
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
builder.endObject();
return builder;
}
@ -165,7 +177,7 @@ public class Regression implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
trainingPercent, randomizeSeed);
}
@Override
@ -180,7 +192,8 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent);
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed);
}
@Override
@ -197,6 +210,7 @@ public class Regression implements DataFrameAnalysis {
private Double featureBagFraction;
private String predictionFieldName;
private Double trainingPercent;
private Long randomizeSeed;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -237,9 +251,14 @@ public class Regression implements DataFrameAnalysis {
return this;
}
public Builder setRandomizeSeed(Long randomizeSeed) {
this.randomizeSeed = randomizeSeed;
return this;
}
public Regression build() {
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
trainingPercent, randomizeSeed);
}
}
}

View File

@ -1321,6 +1321,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setRandomizeSeed(42L)
.build())
.setDescription("this is a regression")
.build();
@ -1356,6 +1357,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setRandomizeSeed(42L)
.setNumTopClasses(1)
.build())
.setDescription("this is a classification")

View File

@ -2975,7 +2975,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.setNumTopClasses(1) // <9>
.setRandomizeSeed(1234L) // <9>
.setNumTopClasses(1) // <10>
.build();
// end::put-data-frame-analytics-classification
@ -2988,6 +2989,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.setRandomizeSeed(1234L) // <9>
.build();
// end::put-data-frame-analytics-regression

View File

@ -34,6 +34,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setRandomizeSeed(randomBoolean() ? null : randomLong())
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.build();
}

View File

@ -119,7 +119,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<9> The number of top classes to be reported in the results. Defaults to 2.
<9> The seed to be used by the random generator that picks which rows are used in training.
<10> The number of top classes to be reported in the results. Defaults to 2.
===== Regression
@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression]
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<9> The seed to be used by the random generator that picks which rows are used in training.
==== Analyzed fields

View File

@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
[float]
[[regression-resources-advanced]]
@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
[float]
[[classification-resources-advanced]]

View File

@ -402,7 +402,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
{
"regression": {
"dependent_variable": "G3",
"training_percent": 70 <1>
"training_percent": 70, <1>
"randomize_seed": 19673948271 <2>
}
}
}
@ -411,6 +412,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
<1> The `training_percent` defines the percentage of the data set that will be used
for training the model.
<2> The `randomize_seed` is the seed used to randomly pick which data is used for training.
[[ml-put-dfanalytics-example-c]]

View File

@ -68,3 +68,17 @@ be used for training. Documents that are ignored by the analysis (for example
those that contain arrays) wont be included in the calculation for used
percentage. Defaults to `100`.
end::training_percent[]
tag::randomize_seed[]
`randomize_seed`::
(Optional, long) Defines the seed to the random generator that is used to pick
which documents will be used for training. By default it is randomly generated.
Set it to a specific value to ensure the same documents are used for training
assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same.
end::randomize_seed[]
tag::use-null[]
Defines whether a new series is used as the null series when there is no value
for the by or partition fields. The default value is `false`.
end::use-null[]

View File

@ -225,7 +225,8 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
builder.field(DEST.getPreferredName(), dest);
builder.startObject(ANALYSIS.getPreferredName());
builder.field(analysis.getWriteableName(), analysis);
builder.field(analysis.getWriteableName(), analysis,
new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString())));
builder.endObject();
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {

View File

@ -49,7 +49,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
BoostedTreeParams(@Nullable Double lambda,
public BoostedTreeParams(@Nullable Double lambda,
@Nullable Double gamma,
@Nullable Double eta,
@Nullable Integer maximumNumberTrees,
@ -76,7 +76,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
this.featureBagFraction = featureBagFraction;
}
BoostedTreeParams() {
public BoostedTreeParams() {
this(null, null, null, null, null);
}

View File

@ -5,8 +5,10 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -35,6 +37,7 @@ public class Classification implements DataFrameAnalysis {
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
@ -48,12 +51,14 @@ public class Classification implements DataFrameAnalysis {
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
(String) a[6],
(Integer) a[7],
(Double) a[8]));
(Double) a[8],
(Long) a[9]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
return parser;
}
@ -84,12 +89,14 @@ public class Classification implements DataFrameAnalysis {
private final String predictionFieldName;
private final int numTopClasses;
private final double trainingPercent;
private final long randomizeSeed;
public Classification(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent) {
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed) {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
}
@ -101,10 +108,11 @@ public class Classification implements DataFrameAnalysis {
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
}
public Classification(String dependentVariable) {
this(dependentVariable, new BoostedTreeParams(), null, null, null);
this(dependentVariable, new BoostedTreeParams(), null, null, null, null);
}
public Classification(StreamInput in) throws IOException {
@ -113,12 +121,21 @@ public class Classification implements DataFrameAnalysis {
predictionFieldName = in.readOptionalString();
numTopClasses = in.readOptionalVInt();
trainingPercent = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
randomizeSeed = in.readOptionalLong();
} else {
randomizeSeed = Randomness.get().nextLong();
}
}
public String getDependentVariable() {
return dependentVariable;
}
public BoostedTreeParams getBoostedTreeParams() {
return boostedTreeParams;
}
public String getPredictionFieldName() {
return predictionFieldName;
}
@ -131,6 +148,11 @@ public class Classification implements DataFrameAnalysis {
return trainingPercent;
}
@Nullable
public Long getRandomizeSeed() {
return randomizeSeed;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -143,10 +165,15 @@ public class Classification implements DataFrameAnalysis {
out.writeOptionalString(predictionFieldName);
out.writeOptionalVInt(numTopClasses);
out.writeDouble(trainingPercent);
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeOptionalLong(randomizeSeed);
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Version version = Version.fromString(params.param("version", Version.CURRENT.toString()));
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
boostedTreeParams.toXContent(builder, params);
@ -155,6 +182,9 @@ public class Classification implements DataFrameAnalysis {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
if (version.onOrAfter(Version.V_7_6_0)) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
builder.endObject();
return builder;
}
@ -240,11 +270,12 @@ public class Classification implements DataFrameAnalysis {
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& trainingPercent == that.trainingPercent;
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed;
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed);
}
}

View File

@ -5,8 +5,10 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -32,6 +34,7 @@ public class Regression implements DataFrameAnalysis {
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
@ -44,11 +47,13 @@ public class Regression implements DataFrameAnalysis {
(String) a[0],
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
(String) a[6],
(Double) a[7]));
(Double) a[7],
(Long) a[8]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
return parser;
}
@ -60,11 +65,13 @@ public class Regression implements DataFrameAnalysis {
private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName;
private final double trainingPercent;
private final long randomizeSeed;
public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed) {
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
@ -72,10 +79,11 @@ public class Regression implements DataFrameAnalysis {
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
}
public Regression(String dependentVariable) {
this(dependentVariable, new BoostedTreeParams(), null, null);
this(dependentVariable, new BoostedTreeParams(), null, null, null);
}
public Regression(StreamInput in) throws IOException {
@ -83,12 +91,21 @@ public class Regression implements DataFrameAnalysis {
boostedTreeParams = new BoostedTreeParams(in);
predictionFieldName = in.readOptionalString();
trainingPercent = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
randomizeSeed = in.readOptionalLong();
} else {
randomizeSeed = Randomness.get().nextLong();
}
}
public String getDependentVariable() {
return dependentVariable;
}
public BoostedTreeParams getBoostedTreeParams() {
return boostedTreeParams;
}
public String getPredictionFieldName() {
return predictionFieldName;
}
@ -97,6 +114,11 @@ public class Regression implements DataFrameAnalysis {
return trainingPercent;
}
@Nullable
public Long getRandomizeSeed() {
return randomizeSeed;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -108,10 +130,15 @@ public class Regression implements DataFrameAnalysis {
boostedTreeParams.writeTo(out);
out.writeOptionalString(predictionFieldName);
out.writeDouble(trainingPercent);
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeOptionalLong(randomizeSeed);
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Version version = Version.fromString(params.param("version", Version.CURRENT.toString()));
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
boostedTreeParams.toXContent(builder, params);
@ -119,6 +146,9 @@ public class Regression implements DataFrameAnalysis {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
if (version.onOrAfter(Version.V_7_6_0)) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
builder.endObject();
return builder;
}
@ -177,11 +207,12 @@ public class Regression implements DataFrameAnalysis {
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& trainingPercent == that.trainingPercent;
&& trainingPercent == that.trainingPercent
&& randomizeSeed == randomizeSeed;
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent);
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
}
}

View File

@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
@ -20,17 +21,20 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before;
@ -42,10 +46,13 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<DataFrameAnalyticsConfig> {
@ -339,6 +346,44 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
}
}
public void testToXContent_GivenAnalysisWithRandomizeSeedAndVersionIsCurrent() throws IOException {
Regression regression = new Regression("foo");
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setVersion(Version.CURRENT)
.setId("test_config")
.setSource(new DataFrameAnalyticsSource(new String[] {"source_index"}, null, null))
.setDest(new DataFrameAnalyticsDest("dest_index", null))
.setAnalysis(regression)
.build();
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
public void testToXContent_GivenAnalysisWithRandomizeSeedAndVersionIsBeforeItWasIntroduced() throws IOException {
Regression regression = new Regression("foo");
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setVersion(Version.V_7_5_0)
.setId("test_config")
.setSource(new DataFrameAnalyticsSource(new String[] {"source_index"}, null, null))
.setDest(new DataFrameAnalyticsDest("dest_index", null))
.setAnalysis(regression)
.build();
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
String json = Strings.toString(builder);
assertThat(json, not(containsString("randomize_seed")));
}
}
private static void assertTooSmall(ElasticsearchStatusException e) {
assertThat(e.getMessage(), startsWith("model_memory_limit must be at least 1kb."));
}

View File

@ -6,8 +6,13 @@
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
@ -20,10 +25,12 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
@ -46,7 +53,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent,
randomizeSeed);
}
@Override
@ -56,71 +65,71 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong()));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong()));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong()));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong()));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
public void testGetPredictionFieldName() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
assertThat(classification.getPredictionFieldName(), equalTo("result"));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong());
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testGetNumTopClasses() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(7));
// Boundary condition: num_top_classes == 0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(0));
// Boundary condition: num_top_classes == 1000
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(1000));
// num_top_classes == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(2));
}
public void testGetTrainingPercent() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
@ -155,4 +164,48 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
Classification classification = createRandom();
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0")));
String json = Strings.toString(builder);
assertThat(json, not(containsString("randomize_seed")));
}
}
public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException {
Classification classification = createRandom();
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString())));
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
public void testToXContent_GivenVersionIsNull() throws IOException {
Classification classification = createRandom();
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null)));
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
public void testToXContent_GivenEmptyParams() throws IOException {
Classification classification = createRandom();
assertThat(classification.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
classification.toXContent(builder, ToXContent.EMPTY_PARAMS);
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
}

View File

@ -6,17 +6,25 @@
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Collections;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
@ -38,7 +46,8 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed);
}
@Override
@ -48,40 +57,40 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong()));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong()));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testGetPredictionFieldName() {
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
assertThat(regression.getPredictionFieldName(), equalTo("result"));
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong());
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testGetTrainingPercent() {
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong());
assertThat(regression.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong());
assertThat(regression.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong());
assertThat(regression.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong());
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
@ -101,4 +110,48 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
String randomId = randomAlphaOfLength(10);
assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
}
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
Regression regression = createRandom();
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0")));
String json = Strings.toString(builder);
assertThat(json, not(containsString("randomize_seed")));
}
}
public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException {
Regression regression = createRandom();
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString())));
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
public void testToXContent_GivenVersionIsNull() throws IOException {
Regression regression = createRandom();
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null)));
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
public void testToXContent_GivenEmptyParams() throws IOException {
Regression regression = createRandom();
assertThat(regression.getRandomizeSeed(), is(notNullValue()));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
regression.toXContent(builder, ToXContent.EMPTY_PARAMS);
String json = Strings.toString(builder);
assertThat(json, containsString("randomize_seed"));
}
}
}

View File

@ -20,6 +20,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
@ -31,6 +32,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.allOf;
@ -158,7 +160,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0));
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null));
registerAnalytics(config);
putAnalytics(config);
@ -268,6 +270,44 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
}
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
String dependentVariable = KEYWORD_FIELD;
indexData(sourceIndex, 10, 0, dependentVariable);
String firstJobId = "classification_two_jobs_with_same_randomize_seed_1";
String firstJobDestIndex = firstJobId + "_dest";
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
registerAnalytics(firstJob);
putAnalytics(firstJob);
String secondJobId = "classification_two_jobs_with_same_randomize_seed_2";
String secondJobDestIndex = secondJobId + "_dest";
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed));
registerAnalytics(secondJob);
putAnalytics(secondJob);
// Let's run both jobs in parallel and wait until they are finished
startAnalytics(firstJobId);
startAnalytics(secondJobId);
waitUntilAnalyticsIsStopped(firstJobId);
waitUntilAnalyticsIsStopped(secondJobId);
// Now we compare they both used the same training rows
Set<String> firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex);
Set<String> secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex);
assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds));
}
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";

View File

@ -18,6 +18,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
@ -45,7 +46,10 @@ import org.hamcrest.Matchers;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@ -252,4 +256,22 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
.map(hit -> (String) hit.getSourceAsMap().get("message"))
.collect(Collectors.toList());
}
protected static Set<String> getTrainingRowsIds(String index) {
Set<String> trainingRowsIds = new HashSet<>();
SearchResponse hits = client().prepareSearch(index).get();
for (SearchHit hit : hits.getHits()) {
Map<String, Object> sourceAsMap = hit.getSourceAsMap();
assertThat(sourceAsMap.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) sourceAsMap.get("ml");
assertThat(resultsObject.containsKey("is_training"), is(true));
if (Boolean.TRUE.equals(resultsObject.get("is_training"))) {
trainingRowsIds.add(hit.getId());
}
}
assertThat(trainingRowsIds.isEmpty(), is(false));
return trainingRowsIds;
}
}

View File

@ -16,6 +16,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@ -25,6 +26,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
@ -139,7 +141,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0));
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null));
registerAnalytics(config);
putAnalytics(config);
@ -235,6 +237,43 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertInferenceModelPersisted(jobId);
}
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
String sourceIndex = "regression_two_jobs_with_same_randomize_seed_source";
indexData(sourceIndex, 10, 0);
String firstJobId = "regression_two_jobs_with_same_randomize_seed_1";
String firstJobDestIndex = firstJobId + "_dest";
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
registerAnalytics(firstJob);
putAnalytics(firstJob);
String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";
String secondJobDestIndex = secondJobId + "_dest";
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed));
registerAnalytics(secondJob);
putAnalytics(secondJob);
// Let's run both jobs in parallel and wait until they are finished
startAnalytics(firstJobId);
startAnalytics(secondJobId);
waitUntilAnalyticsIsStopped(firstJobId);
waitUntilAnalyticsIsStopped(secondJobId);
// Now we compare they both used the same training rows
Set<String> firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex);
Set<String> secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex);
assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds));
}
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";

View File

@ -111,7 +111,7 @@ public class TransportPutDataFrameAnalyticsAction
protected void masterOperation(PutDataFrameAnalyticsAction.Request request, ClusterState state,
ActionListener<PutDataFrameAnalyticsAction.Response> listener) {
validateConfig(request.getConfig());
DataFrameAnalyticsConfig memoryCappedConfig =
DataFrameAnalyticsConfig preparedForPutConfig =
new DataFrameAnalyticsConfig.Builder(request.getConfig(), maxModelMemoryLimit)
.setCreateTime(Instant.now())
.setVersion(Version.CURRENT)
@ -120,11 +120,11 @@ public class TransportPutDataFrameAnalyticsAction
if (licenseState.isAuthAllowed()) {
final String username = securityContext.getUser().principal();
RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
.indices(memoryCappedConfig.getSource().getIndex())
.indices(preparedForPutConfig.getSource().getIndex())
.privileges("read")
.build();
RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder()
.indices(memoryCappedConfig.getDest().getIndex())
.indices(preparedForPutConfig.getDest().getIndex())
.privileges("read", "index", "create_index")
.build();
@ -135,16 +135,16 @@ public class TransportPutDataFrameAnalyticsAction
privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges);
ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
r -> handlePrivsResponse(username, memoryCappedConfig, r, listener),
r -> handlePrivsResponse(username, preparedForPutConfig, r, listener),
listener::onFailure);
client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
} else {
updateDocMappingAndPutConfig(
memoryCappedConfig,
preparedForPutConfig,
threadPool.getThreadContext().getHeaders(),
ActionListener.wrap(
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(memoryCappedConfig)),
indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(preparedForPutConfig)),
listener::onFailure
));
}

View File

@ -24,12 +24,12 @@ public class CustomProcessorFactory {
if (analysis instanceof Regression) {
Regression regression = (Regression) analysis;
return new DatasetSplittingCustomProcessor(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent());
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
}
if (analysis instanceof Classification) {
Classification classification = (Classification) analysis;
return new DatasetSplittingCustomProcessor(
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent());
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
}
return row -> {};
}

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.List;
@ -23,12 +22,13 @@ class DatasetSplittingCustomProcessor implements CustomProcessor {
private final int dependentVariableIndex;
private final double trainingPercent;
private final Random random = Randomness.get();
private final Random random;
private boolean isFirstRow = true;
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent) {
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.trainingPercent = trainingPercent;
this.random = new Random(randomizeSeed);
}
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {

View File

@ -24,6 +24,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
private List<String> fields;
private int dependentVariableIndex;
private String dependentVariable;
private long randomizeSeed;
@Before
public void setUpTests() {
@ -34,10 +35,11 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
}
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
dependentVariable = fields.get(dependentVariableIndex);
randomizeSeed = randomLong();
}
public void testProcess_GivenRowsWithoutDependentVariableValue() {
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0);
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed);
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
@ -55,7 +57,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
}
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0);
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed);
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
@ -75,7 +77,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
double trainingFraction = trainingPercent / 100;
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent);
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed);
int runCount = 20;
int rowsCount = 1000;
@ -121,7 +123,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
}
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0);
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed);
// We have some non-training rows and then a training row to check
// we maintain the first training row and not just the first row

View File

@ -1456,7 +1456,8 @@ setup:
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"training_percent": 60.3
"training_percent": 60.3,
"randomize_seed": 42
}
}
}
@ -1472,7 +1473,8 @@ setup:
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3
"training_percent": 60.3,
"randomize_seed": 42
}
}}
- is_true: create_time
@ -1796,7 +1798,8 @@ setup:
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"training_percent": 60.3
"training_percent": 60.3,
"randomize_seed": 24
}
}
}
@ -1813,6 +1816,7 @@ setup:
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3,
"randomize_seed": 24,
"num_top_classes": 2
}
}}
@ -1836,7 +1840,8 @@ setup:
},
"analysis": {
"regression": {
"dependent_variable": "foo"
"dependent_variable": "foo",
"randomize_seed": 42
}
}
}
@ -1848,7 +1853,8 @@ setup:
"regression":{
"dependent_variable": "foo",
"prediction_field_name": "foo_prediction",
"training_percent": 100.0
"training_percent": 100.0,
"randomize_seed": 42
}
}}
- is_true: create_time