[7.x][ML] Data frame analytics analysis stats (#53788) (#53844)

Adds parsing and indexing of analysis instrumentation stats.
The latest one is also returned from the get-stats API.

Note that we chose to duplicate objects even where they are currently
similar. There are already ideas on how these will diverge in the future
and while the duplication looks ugly at the moment, it is the option
that offers the highest flexibility.

Backport of #53788
This commit is contained in:
Dimitris Athanasiou 2020-03-20 12:11:53 +02:00 committed by GitHub
parent f7143b8d85
commit 60153c5433
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
72 changed files with 4981 additions and 54 deletions

View File

@ -20,12 +20,15 @@
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.client.ml.NodeAttributes;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.inject.internal.ToStringBuilder;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import java.io.IOException;
import java.util.List;
@ -45,6 +48,7 @@ public class DataFrameAnalyticsStats {
static final ParseField FAILURE_REASON = new ParseField("failure_reason");
static final ParseField PROGRESS = new ParseField("progress");
static final ParseField MEMORY_USAGE = new ParseField("memory_usage");
static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats");
static final ParseField NODE = new ParseField("node");
static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation");
@ -57,8 +61,9 @@ public class DataFrameAnalyticsStats {
(String) args[2],
(List<PhaseProgress>) args[3],
(MemoryUsage) args[4],
(NodeAttributes) args[5],
(String) args[6]));
(AnalysisStats) args[5],
(NodeAttributes) args[6],
(String) args[7]));
static {
PARSER.declareString(constructorArg(), ID);
@ -71,26 +76,38 @@ public class DataFrameAnalyticsStats {
PARSER.declareString(optionalConstructorArg(), FAILURE_REASON);
PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS);
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS);
PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE);
PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION);
}
/**
 * Parses the {@code analysis_stats} object. The wrapper object is expected to contain a
 * single field whose name identifies the concrete {@link AnalysisStats} implementation,
 * which is then resolved through the parser's named-xcontent registry.
 */
private static AnalysisStats parseAnalysisStats(XContentParser parser) throws IOException {
    // Enter the wrapper object and position on the type-discriminating field name.
    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
    XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
    // Delegate to the parser registered under the current field name.
    AnalysisStats analysisStats = parser.namedObject(AnalysisStats.class, parser.currentName(), true);
    // Consume the wrapper object's closing END_OBJECT token.
    XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
    return analysisStats;
}
private final String id;
private final DataFrameAnalyticsState state;
private final String failureReason;
private final List<PhaseProgress> progress;
private final MemoryUsage memoryUsage;
private final AnalysisStats analysisStats;
private final NodeAttributes node;
private final String assignmentExplanation;
public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason,
@Nullable List<PhaseProgress> progress, @Nullable MemoryUsage memoryUsage,
@Nullable NodeAttributes node, @Nullable String assignmentExplanation) {
@Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node,
@Nullable String assignmentExplanation) {
this.id = id;
this.state = state;
this.failureReason = failureReason;
this.progress = progress;
this.memoryUsage = memoryUsage;
this.analysisStats = analysisStats;
this.node = node;
this.assignmentExplanation = assignmentExplanation;
}
@ -116,6 +133,11 @@ public class DataFrameAnalyticsStats {
return memoryUsage;
}
/**
 * @return statistics specific to the analysis type, or {@code null} if not available
 */
@Nullable
public AnalysisStats getAnalysisStats() {
    return analysisStats;
}
public NodeAttributes getNode() {
return node;
}
@ -135,13 +157,14 @@ public class DataFrameAnalyticsStats {
&& Objects.equals(failureReason, other.failureReason)
&& Objects.equals(progress, other.progress)
&& Objects.equals(memoryUsage, other.memoryUsage)
&& Objects.equals(analysisStats, other.analysisStats)
&& Objects.equals(node, other.node)
&& Objects.equals(assignmentExplanation, other.assignmentExplanation);
}
@Override
public int hashCode() {
return Objects.hash(id, state, failureReason, progress, memoryUsage, node, assignmentExplanation);
return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
}
@Override
@ -152,6 +175,7 @@ public class DataFrameAnalyticsStats {
.add("failureReason", failureReason)
.add("progress", progress)
.add("memoryUsage", memoryUsage)
.add("analysisStats", analysisStats)
.add("node", node)
.add("assignmentExplanation", assignmentExplanation)
.toString();

View File

@ -0,0 +1,29 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats;
import org.elasticsearch.common.xcontent.ToXContentObject;
/**
 * Statistics for the data frame analysis.
 *
 * Implementations are concrete per-analysis stats objects (classification, regression,
 * outlier detection) registered as named xcontent and selected by name at parse time.
 */
public interface AnalysisStats extends ToXContentObject {

    /**
     * @return the name identifying this statistics type in the named-xcontent registry
     */
    String getName();
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats;
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import java.util.Arrays;
import java.util.List;
/**
 * Provides the named-xcontent registry entries for all concrete
 * {@link AnalysisStats} implementations, so the right parser can be
 * selected by name when reading an {@code analysis_stats} object.
 */
public class AnalysisStatsNamedXContentProvider implements NamedXContentProvider {

    @Override
    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
        // One registry entry per concrete AnalysisStats implementation, keyed by its NAME field.
        NamedXContentRegistry.Entry classification = new NamedXContentRegistry.Entry(
            AnalysisStats.class,
            ClassificationStats.NAME,
            (p, c) -> ClassificationStats.PARSER.apply(p, null));
        NamedXContentRegistry.Entry outlierDetection = new NamedXContentRegistry.Entry(
            AnalysisStats.class,
            OutlierDetectionStats.NAME,
            (p, c) -> OutlierDetectionStats.PARSER.apply(p, null));
        NamedXContentRegistry.Entry regression = new NamedXContentRegistry.Entry(
            AnalysisStats.class,
            RegressionStats.NAME,
            (p, c) -> RegressionStats.PARSER.apply(p, null));
        return Arrays.asList(classification, outlierDetection, regression);
    }
}

View File

@ -0,0 +1,135 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
/**
 * Statistics for a running classification analysis, returned by the
 * get data frame analytics stats API under {@code analysis_stats}.
 */
public class ClassificationStats implements AnalysisStats {

    public static final ParseField NAME = new ParseField("classification_stats");

    public static final ParseField TIMESTAMP = new ParseField("timestamp");
    public static final ParseField ITERATION = new ParseField("iteration");
    public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
    public static final ParseField TIMING_STATS = new ParseField("timing_stats");
    public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");

    // Lenient parser (second argument 'true') so responses from newer servers with
    // additional fields can still be parsed by this client.
    public static final ConstructingObjectParser<ClassificationStats, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
        true,
        a -> new ClassificationStats(
            (Instant) a[0],
            (Integer) a[1],
            (Hyperparameters) a[2],
            (TimingStats) a[3],
            (ValidationLoss) a[4]
        )
    );

    static {
        // Declaration order determines the constructor argument indices a[0]..a[4] above.
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
            TIMESTAMP,
            ObjectParser.ValueType.VALUE);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), ITERATION);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), Hyperparameters.PARSER, HYPERPARAMETERS);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), ValidationLoss.PARSER, VALIDATION_LOSS);
    }

    private final Instant timestamp;
    private final Integer iteration;  // optional; may be null
    private final Hyperparameters hyperparameters;
    private final TimingStats timingStats;
    private final ValidationLoss validationLoss;

    /**
     * @param timestamp       time the stats were produced; required, truncated to millisecond precision
     * @param iteration       current iteration, or {@code null} if not reported
     * @param hyperparameters hyperparameters in use; required
     * @param timingStats     timing statistics; required
     * @param validationLoss  validation loss values; required
     */
    public ClassificationStats(Instant timestamp, Integer iteration, Hyperparameters hyperparameters, TimingStats timingStats,
                               ValidationLoss validationLoss) {
        // Normalize to millisecond precision to match the wire format.
        this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
        this.iteration = iteration;
        this.hyperparameters = Objects.requireNonNull(hyperparameters);
        this.timingStats = Objects.requireNonNull(timingStats);
        this.validationLoss = Objects.requireNonNull(validationLoss);
    }

    public Instant getTimestamp() {
        return timestamp;
    }

    public Integer getIteration() {
        return iteration;
    }

    public Hyperparameters getHyperparameters() {
        return hyperparameters;
    }

    public TimingStats getTimingStats() {
        return timingStats;
    }

    public ValidationLoss getValidationLoss() {
        return validationLoss;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        // timeField also emits a human-readable "timestamp_string" variant.
        builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
        if (iteration != null) {
            builder.field(ITERATION.getPreferredName(), iteration);
        }
        builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
        builder.field(TIMING_STATS.getPreferredName(), timingStats);
        builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ClassificationStats that = (ClassificationStats) o;
        return Objects.equals(timestamp, that.timestamp)
            && Objects.equals(iteration, that.iteration)
            && Objects.equals(hyperparameters, that.hyperparameters)
            && Objects.equals(timingStats, that.timingStats)
            && Objects.equals(validationLoss, that.validationLoss);
    }

    @Override
    public int hashCode() {
        return Objects.hash(timestamp, iteration, hyperparameters, timingStats, validationLoss);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }
}

View File

@ -0,0 +1,293 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * Hyperparameters reported for a classification analysis. All fields are
 * optional and may be {@code null} if the server did not report them.
 */
public class Hyperparameters implements ToXContentObject {

    public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
    public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
    public static final ParseField ETA = new ParseField("eta");
    public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
    public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
    public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
        "max_optimization_rounds_per_hyperparameter");
    public static final ParseField MAX_TREES = new ParseField("max_trees");
    public static final ParseField NUM_FOLDS = new ParseField("num_folds");
    public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
        = new ParseField("regularization_leaf_weight_penalty_multiplier");
    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
        new ParseField("regularization_tree_size_penalty_multiplier");

    // Declared final (was a mutable public static field): the parser is a constant and
    // must never be reassigned; this also matches ClassificationStats.PARSER.
    public static final ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>(
        "classification_hyperparameters",
        true,
        a -> new Hyperparameters(
            (String) a[0],
            (Double) a[1],
            (Double) a[2],
            (Double) a[3],
            (Double) a[4],
            (Integer) a[5],
            (Integer) a[6],
            (Integer) a[7],
            (Integer) a[8],
            (Integer) a[9],
            (Double) a[10],
            (Double) a[11],
            (Double) a[12],
            (Double) a[13],
            (Double) a[14]
        ));

    static {
        // Declaration order determines the constructor argument indices a[0]..a[14] above.
        PARSER.declareString(optionalConstructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
        PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
        PARSER.declareDouble(optionalConstructorArg(), ETA);
        PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
        PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
        PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
        PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
        PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
        PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
        PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
    }

    private final String classAssignmentObjective;
    private final Double downsampleFactor;
    private final Double eta;
    private final Double etaGrowthRatePerTree;
    private final Double featureBagFraction;
    private final Integer maxAttemptsToAddTree;
    private final Integer maxOptimizationRoundsPerHyperparameter;
    private final Integer maxTrees;
    private final Integer numFolds;
    private final Integer numSplitsPerFeature;
    private final Double regularizationDepthPenaltyMultiplier;
    private final Double regularizationLeafWeightPenaltyMultiplier;
    private final Double regularizationSoftTreeDepthLimit;
    private final Double regularizationSoftTreeDepthTolerance;
    private final Double regularizationTreeSizePenaltyMultiplier;

    /**
     * All parameters are optional; {@code null} means the value was not reported.
     */
    public Hyperparameters(String classAssignmentObjective,
                           Double downsampleFactor,
                           Double eta,
                           Double etaGrowthRatePerTree,
                           Double featureBagFraction,
                           Integer maxAttemptsToAddTree,
                           Integer maxOptimizationRoundsPerHyperparameter,
                           Integer maxTrees,
                           Integer numFolds,
                           Integer numSplitsPerFeature,
                           Double regularizationDepthPenaltyMultiplier,
                           Double regularizationLeafWeightPenaltyMultiplier,
                           Double regularizationSoftTreeDepthLimit,
                           Double regularizationSoftTreeDepthTolerance,
                           Double regularizationTreeSizePenaltyMultiplier) {
        this.classAssignmentObjective = classAssignmentObjective;
        this.downsampleFactor = downsampleFactor;
        this.eta = eta;
        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
        this.featureBagFraction = featureBagFraction;
        this.maxAttemptsToAddTree = maxAttemptsToAddTree;
        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
        this.maxTrees = maxTrees;
        this.numFolds = numFolds;
        this.numSplitsPerFeature = numSplitsPerFeature;
        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
    }

    /** @return the class assignment objective, or {@code null} if not reported */
    public String getClassAssignmentObjective() {
        return classAssignmentObjective;
    }

    /** @return the downsample factor, or {@code null} if not reported */
    public Double getDownsampleFactor() {
        return downsampleFactor;
    }

    /** @return eta, or {@code null} if not reported */
    public Double getEta() {
        return eta;
    }

    /** @return the eta growth rate per tree, or {@code null} if not reported */
    public Double getEtaGrowthRatePerTree() {
        return etaGrowthRatePerTree;
    }

    /** @return the feature bag fraction, or {@code null} if not reported */
    public Double getFeatureBagFraction() {
        return featureBagFraction;
    }

    /** @return the max attempts to add a tree, or {@code null} if not reported */
    public Integer getMaxAttemptsToAddTree() {
        return maxAttemptsToAddTree;
    }

    /** @return the max optimization rounds per hyperparameter, or {@code null} if not reported */
    public Integer getMaxOptimizationRoundsPerHyperparameter() {
        return maxOptimizationRoundsPerHyperparameter;
    }

    /** @return the max number of trees, or {@code null} if not reported */
    public Integer getMaxTrees() {
        return maxTrees;
    }

    /** @return the number of folds, or {@code null} if not reported */
    public Integer getNumFolds() {
        return numFolds;
    }

    /** @return the number of splits per feature, or {@code null} if not reported */
    public Integer getNumSplitsPerFeature() {
        return numSplitsPerFeature;
    }

    /** @return the regularization depth penalty multiplier, or {@code null} if not reported */
    public Double getRegularizationDepthPenaltyMultiplier() {
        return regularizationDepthPenaltyMultiplier;
    }

    /** @return the regularization leaf weight penalty multiplier, or {@code null} if not reported */
    public Double getRegularizationLeafWeightPenaltyMultiplier() {
        return regularizationLeafWeightPenaltyMultiplier;
    }

    /** @return the regularization soft tree depth limit, or {@code null} if not reported */
    public Double getRegularizationSoftTreeDepthLimit() {
        return regularizationSoftTreeDepthLimit;
    }

    /** @return the regularization soft tree depth tolerance, or {@code null} if not reported */
    public Double getRegularizationSoftTreeDepthTolerance() {
        return regularizationSoftTreeDepthTolerance;
    }

    /** @return the regularization tree size penalty multiplier, or {@code null} if not reported */
    public Double getRegularizationTreeSizePenaltyMultiplier() {
        return regularizationTreeSizePenaltyMultiplier;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Only fields that were present are serialized back out.
        if (classAssignmentObjective != null) {
            builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
        }
        if (downsampleFactor != null) {
            builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
        }
        if (eta != null) {
            builder.field(ETA.getPreferredName(), eta);
        }
        if (etaGrowthRatePerTree != null) {
            builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
        }
        if (featureBagFraction != null) {
            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        }
        if (maxAttemptsToAddTree != null) {
            builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
        }
        if (maxOptimizationRoundsPerHyperparameter != null) {
            builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
        }
        if (maxTrees != null) {
            builder.field(MAX_TREES.getPreferredName(), maxTrees);
        }
        if (numFolds != null) {
            builder.field(NUM_FOLDS.getPreferredName(), numFolds);
        }
        if (numSplitsPerFeature != null) {
            builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
        }
        if (regularizationDepthPenaltyMultiplier != null) {
            builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
        }
        if (regularizationLeafWeightPenaltyMultiplier != null) {
            builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
        }
        if (regularizationSoftTreeDepthLimit != null) {
            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
        }
        if (regularizationSoftTreeDepthTolerance != null) {
            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
        }
        if (regularizationTreeSizePenaltyMultiplier != null) {
            builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Hyperparameters that = (Hyperparameters) o;
        return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
            && Objects.equals(downsampleFactor, that.downsampleFactor)
            && Objects.equals(eta, that.eta)
            && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
            && Objects.equals(featureBagFraction, that.featureBagFraction)
            && Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
            && Objects.equals(maxTrees, that.maxTrees)
            && Objects.equals(numFolds, that.numFolds)
            && Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
            && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
            && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
            && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
            && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
            && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            classAssignmentObjective,
            downsampleFactor,
            eta,
            etaGrowthRatePerTree,
            featureBagFraction,
            maxAttemptsToAddTree,
            maxOptimizationRoundsPerHyperparameter,
            maxTrees,
            numFolds,
            numSplitsPerFeature,
            regularizationDepthPenaltyMultiplier,
            regularizationLeafWeightPenaltyMultiplier,
            regularizationSoftTreeDepthLimit,
            regularizationSoftTreeDepthTolerance,
            regularizationTreeSizePenaltyMultiplier
        );
    }
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
 * Timing statistics for a classification analysis. Both fields are optional
 * and may be {@code null} if the server did not report them.
 */
public class TimingStats implements ToXContentObject {

    public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
    public static final ParseField ITERATION_TIME = new ParseField("iteration_time");

    // Declared final (was a mutable public static field): the parser is a constant and
    // must never be reassigned; this also matches ClassificationStats.PARSER.
    public static final ConstructingObjectParser<TimingStats, Void> PARSER = new ConstructingObjectParser<>(
        "classification_timing_stats", true,
        a -> new TimingStats(
            // Wire format carries millisecond longs; convert to TimeValue, keeping null for absent.
            a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]),
            a[1] == null ? null : TimeValue.timeValueMillis((long) a[1])
        ));

    static {
        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME);
        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ITERATION_TIME);
    }

    private final TimeValue elapsedTime;
    private final TimeValue iterationTime;

    /**
     * @param elapsedTime   total elapsed time, or {@code null} if not reported
     * @param iterationTime time taken by the latest iteration, or {@code null} if not reported
     */
    public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
        this.elapsedTime = elapsedTime;
        this.iterationTime = iterationTime;
    }

    /** @return the elapsed time, or {@code null} if not reported */
    public TimeValue getElapsedTime() {
        return elapsedTime;
    }

    /** @return the iteration time, or {@code null} if not reported */
    public TimeValue getIterationTime() {
        return iterationTime;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // humanReadableField also emits a "<name>_string" human-readable variant.
        if (elapsedTime != null) {
            builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
        }
        if (iterationTime != null) {
            builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TimingStats that = (TimingStats) o;
        return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime);
    }

    @Override
    public int hashCode() {
        return Objects.hash(elapsedTime, iterationTime);
    }
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValues;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * Validation loss statistics for a classification analysis. Both fields are
 * optional and may be {@code null} if the server did not report them.
 */
public class ValidationLoss implements ToXContentObject {

    public static final ParseField LOSS_TYPE = new ParseField("loss_type");
    public static final ParseField FOLD_VALUES = new ParseField("fold_values");

    // Declared final (was a mutable public static field): the parser is a constant and
    // must never be reassigned; this also matches ClassificationStats.PARSER.
    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<ValidationLoss, Void> PARSER = new ConstructingObjectParser<>(
        "classification_validation_loss",
        true,
        a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));

    static {
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LOSS_TYPE);
        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), FoldValues.PARSER, FOLD_VALUES);
    }

    private final String lossType;
    private final List<FoldValues> foldValues;

    /**
     * @param lossType the type of loss metric, or {@code null} if not reported
     * @param values   per-fold loss values, or {@code null} if not reported
     */
    public ValidationLoss(String lossType, List<FoldValues> values) {
        this.lossType = lossType;
        this.foldValues = values;
    }

    /** @return the loss type, or {@code null} if not reported */
    public String getLossType() {
        return lossType;
    }

    /** @return the per-fold loss values, or {@code null} if not reported */
    public List<FoldValues> getFoldValues() {
        return foldValues;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Only fields that were present are serialized back out.
        if (lossType != null) {
            builder.field(LOSS_TYPE.getPreferredName(), lossType);
        }
        if (foldValues != null) {
            builder.field(FOLD_VALUES.getPreferredName(), foldValues);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ValidationLoss that = (ValidationLoss) o;
        return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
    }

    @Override
    public int hashCode() {
        return Objects.hash(lossType, foldValues);
    }
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.common;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
 * The loss values for a single cross-validation fold.
 */
public class FoldValues implements ToXContentObject {

    public static final ParseField FOLD = new ParseField("fold");
    public static final ParseField VALUES = new ParseField("values");

    // Declared final (was a mutable public static field): the parser is a constant and
    // must never be reassigned; this also matches ClassificationStats.PARSER.
    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<FoldValues, Void> PARSER = new ConstructingObjectParser<>("fold_values", true,
        a -> new FoldValues((int) a[0], (List<Double>) a[1]));

    static {
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), FOLD);
        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), VALUES);
    }

    private final int fold;
    private final double[] values;

    // Used by PARSER only: unboxes the parsed List<Double> into a primitive array.
    private FoldValues(int fold, List<Double> values) {
        this(fold, values.stream().mapToDouble(Double::doubleValue).toArray());
    }

    /**
     * @param fold   the fold index
     * @param values the loss values for this fold; the array is stored without copying
     */
    public FoldValues(int fold, double[] values) {
        this.fold = fold;
        this.values = values;
    }

    /** @return the fold index */
    public int getFold() {
        return fold;
    }

    // NOTE(review): returns the internal array without a defensive copy, so callers can
    // mutate this object's state; kept as-is to preserve the existing contract.
    public double[] getValues() {
        return values;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FOLD.getPreferredName(), fold);
        builder.array(VALUES.getPreferredName(), values);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) return true;
        if (o == null || getClass() != o.getClass()) return false;
        FoldValues other = (FoldValues) o;
        // Arrays.equals gives element-wise comparison (double[] has identity equals).
        return fold == other.fold && Arrays.equals(values, other.values);
    }

    @Override
    public int hashCode() {
        return Objects.hash(fold, Arrays.hashCode(values));
    }
}

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
package org.elasticsearch.client.ml.dataframe.stats.common;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.ParseField;
@ -54,6 +54,14 @@ public class MemoryUsage implements ToXContentObject {
this.peakUsageBytes = peakUsageBytes;
}
/**
 * @return the time at which this memory usage snapshot was taken
 */
public Instant getTimestamp() {
return timestamp;
}
/**
 * @return the peak memory usage of the analytics job, in bytes
 */
public long getPeakUsageBytes() {
return peakUsageBytes;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

View File

@ -0,0 +1,105 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
/**
 * Analysis instrumentation stats for an outlier detection job: a snapshot
 * timestamp, the effective analysis parameters, and timing information.
 * Parsed leniently so unknown fields from newer servers are tolerated.
 */
public class OutlierDetectionStats implements AnalysisStats {
public static final ParseField NAME = new ParseField("outlier_detection_stats");
public static final ParseField TIMESTAMP = new ParseField("timestamp");
public static final ParseField PARAMETERS = new ParseField("parameters");
// NOTE(review): field name is "timings_stats" (plural "timings") while the
// regression equivalent uses "timing_stats" — presumably this matches the
// server-side wire format; confirm against the server before changing it.
public static final ParseField TIMING_STATS = new ParseField("timings_stats");
public static final ConstructingObjectParser<OutlierDetectionStats, Void> PARSER = new ConstructingObjectParser<>(
NAME.getPreferredName(), true,
a -> new OutlierDetectionStats((Instant) a[0], (Parameters) a[1], (TimingStats) a[2]));
static {
// Timestamp may arrive as epoch millis or a formatted string; TimeUtil handles both
PARSER.declareField(ConstructingObjectParser.constructorArg(),
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
TIMESTAMP,
ObjectParser.ValueType.VALUE);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Parameters.PARSER, PARAMETERS);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS);
}
// Time the stats snapshot was taken, truncated to millisecond precision
private final Instant timestamp;
private final Parameters parameters;
private final TimingStats timingStats;
public OutlierDetectionStats(Instant timestamp, Parameters parameters, TimingStats timingStats) {
// Round-trip through epoch millis drops sub-millisecond precision, matching
// the precision used when the timestamp is serialized in toXContent
this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
this.parameters = Objects.requireNonNull(parameters);
this.timingStats = Objects.requireNonNull(timingStats);
}
public Instant getTimestamp() {
return timestamp;
}
public Parameters getParameters() {
return parameters;
}
public TimingStats getTimingStats() {
return timingStats;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
// Emits both the raw millis and a human-readable "<field>_string" variant
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
builder.field(PARAMETERS.getPreferredName(), parameters);
builder.field(TIMING_STATS.getPreferredName(), timingStats);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OutlierDetectionStats that = (OutlierDetectionStats) o;
return Objects.equals(timestamp, that.timestamp)
&& Objects.equals(parameters, that.parameters)
&& Objects.equals(timingStats, that.timingStats);
}
@Override
public int hashCode() {
return Objects.hash(timestamp, parameters, timingStats);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
}

View File

@ -0,0 +1,146 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * The effective parameters an outlier detection analysis ran with.
 * All fields are optional: the server may omit any of them, in which case
 * the corresponding getter returns {@code null} and the field is not re-serialized.
 */
public class Parameters implements ToXContentObject {

    public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors");
    public static final ParseField METHOD = new ParseField("method");
    public static final ParseField FEATURE_INFLUENCE_THRESHOLD = new ParseField("feature_influence_threshold");
    public static final ParseField COMPUTE_FEATURE_INFLUENCE = new ParseField("compute_feature_influence");
    public static final ParseField OUTLIER_FRACTION = new ParseField("outlier_fraction");
    public static final ParseField STANDARDIZATION_ENABLED = new ParseField("standardization_enabled");

    // final: parsers are shared, stateless singletons and must never be reassigned
    // (matches OutlierDetectionStats.PARSER / RegressionStats.PARSER in this commit)
    public static final ConstructingObjectParser<Parameters, Void> PARSER = new ConstructingObjectParser<>("outlier_detection_parameters",
        true,
        a -> new Parameters(
            (Integer) a[0],
            (String) a[1],
            (Boolean) a[2],
            (Double) a[3],
            (Double) a[4],
            (Boolean) a[5]
        ));

    static {
        // Declaration order fixes the a[0]..a[5] indices used in the lambda above;
        // keep it in sync with the constructor argument order
        PARSER.declareInt(optionalConstructorArg(), N_NEIGHBORS);
        PARSER.declareString(optionalConstructorArg(), METHOD);
        PARSER.declareBoolean(optionalConstructorArg(), COMPUTE_FEATURE_INFLUENCE);
        PARSER.declareDouble(optionalConstructorArg(), FEATURE_INFLUENCE_THRESHOLD);
        PARSER.declareDouble(optionalConstructorArg(), OUTLIER_FRACTION);
        PARSER.declareBoolean(optionalConstructorArg(), STANDARDIZATION_ENABLED);
    }

    private final Integer nNeighbors;
    private final String method;
    private final Boolean computeFeatureInfluence;
    private final Double featureInfluenceThreshold;
    private final Double outlierFraction;
    private final Boolean standardizationEnabled;

    public Parameters(Integer nNeighbors, String method, Boolean computeFeatureInfluence, Double featureInfluenceThreshold,
                      Double outlierFraction, Boolean standardizationEnabled) {
        this.nNeighbors = nNeighbors;
        this.method = method;
        this.computeFeatureInfluence = computeFeatureInfluence;
        this.featureInfluenceThreshold = featureInfluenceThreshold;
        this.outlierFraction = outlierFraction;
        this.standardizationEnabled = standardizationEnabled;
    }

    // Name kept as-is (lowercase 'n' prefix) for backwards compatibility with existing callers
    public Integer getnNeighbors() {
        return nNeighbors;
    }

    public String getMethod() {
        return method;
    }

    public Boolean getComputeFeatureInfluence() {
        return computeFeatureInfluence;
    }

    public Double getFeatureInfluenceThreshold() {
        return featureInfluenceThreshold;
    }

    public Double getOutlierFraction() {
        return outlierFraction;
    }

    public Boolean getStandardizationEnabled() {
        return standardizationEnabled;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Only fields that were actually present are serialized back
        if (nNeighbors != null) {
            builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors);
        }
        if (method != null) {
            builder.field(METHOD.getPreferredName(), method);
        }
        if (computeFeatureInfluence != null) {
            builder.field(COMPUTE_FEATURE_INFLUENCE.getPreferredName(), computeFeatureInfluence);
        }
        if (featureInfluenceThreshold != null) {
            builder.field(FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), featureInfluenceThreshold);
        }
        if (outlierFraction != null) {
            builder.field(OUTLIER_FRACTION.getPreferredName(), outlierFraction);
        }
        if (standardizationEnabled != null) {
            builder.field(STANDARDIZATION_ENABLED.getPreferredName(), standardizationEnabled);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Parameters that = (Parameters) o;
        return Objects.equals(nNeighbors, that.nNeighbors)
            && Objects.equals(method, that.method)
            && Objects.equals(computeFeatureInfluence, that.computeFeatureInfluence)
            && Objects.equals(featureInfluenceThreshold, that.featureInfluenceThreshold)
            && Objects.equals(outlierFraction, that.outlierFraction)
            && Objects.equals(standardizationEnabled, that.standardizationEnabled);
    }

    @Override
    public int hashCode() {
        return Objects.hash(nNeighbors, method, computeFeatureInfluence, featureInfluenceThreshold, outlierFraction,
            standardizationEnabled);
    }
}

View File

@ -0,0 +1,74 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
 * Timing information for an outlier detection analysis.
 * Currently only carries the total elapsed time, which may be absent.
 */
public class TimingStats implements ToXContentObject {

    public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");

    // final: parsers are shared, stateless singletons and must never be reassigned
    // (matches OutlierDetectionStats.PARSER / RegressionStats.PARSER in this commit)
    public static final ConstructingObjectParser<TimingStats, Void> PARSER = new ConstructingObjectParser<>(
        "outlier_detection_timing_stats",
        true,
        a -> new TimingStats(a[0] == null ? null : TimeValue.timeValueMillis((long) a[0])));

    static {
        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME);
    }

    // Total time the analysis has been running; may be null if not reported
    private final TimeValue elapsedTime;

    public TimingStats(TimeValue elapsedTime) {
        this.elapsedTime = elapsedTime;
    }

    public TimeValue getElapsedTime() {
        return elapsedTime;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (elapsedTime != null) {
            // Emits both raw millis and a human-readable "<field>_string" variant
            builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TimingStats that = (TimingStats) o;
        return Objects.equals(elapsedTime, that.elapsedTime);
    }

    @Override
    public int hashCode() {
        return Objects.hash(elapsedTime);
    }
}

View File

@ -0,0 +1,278 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * The hyperparameters a regression analysis ran with, as reported in the
 * analysis instrumentation stats. Every field is optional: the server may
 * omit any of them, in which case the getter returns {@code null} and the
 * field is not re-serialized.
 */
public class Hyperparameters implements ToXContentObject {

    public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
    public static final ParseField ETA = new ParseField("eta");
    public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
    public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
    public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
        "max_optimization_rounds_per_hyperparameter");
    public static final ParseField MAX_TREES = new ParseField("max_trees");
    public static final ParseField NUM_FOLDS = new ParseField("num_folds");
    public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
        = new ParseField("regularization_leaf_weight_penalty_multiplier");
    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
        new ParseField("regularization_tree_size_penalty_multiplier");

    // final: parsers are shared, stateless singletons and must never be reassigned
    // (matches RegressionStats.PARSER in this commit)
    public static final ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>(
        "regression_hyperparameters",
        true,
        a -> new Hyperparameters(
            (Double) a[0],
            (Double) a[1],
            (Double) a[2],
            (Double) a[3],
            (Integer) a[4],
            (Integer) a[5],
            (Integer) a[6],
            (Integer) a[7],
            (Integer) a[8],
            (Double) a[9],
            (Double) a[10],
            (Double) a[11],
            (Double) a[12],
            (Double) a[13]
        ));

    static {
        // Declaration order fixes the a[0]..a[13] indices used in the lambda above;
        // keep it in sync with the constructor argument order
        PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
        PARSER.declareDouble(optionalConstructorArg(), ETA);
        PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
        PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
        PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
        PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
        PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
        PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
        PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
    }

    private final Double downsampleFactor;
    private final Double eta;
    private final Double etaGrowthRatePerTree;
    private final Double featureBagFraction;
    private final Integer maxAttemptsToAddTree;
    private final Integer maxOptimizationRoundsPerHyperparameter;
    private final Integer maxTrees;
    private final Integer numFolds;
    private final Integer numSplitsPerFeature;
    private final Double regularizationDepthPenaltyMultiplier;
    private final Double regularizationLeafWeightPenaltyMultiplier;
    private final Double regularizationSoftTreeDepthLimit;
    private final Double regularizationSoftTreeDepthTolerance;
    private final Double regularizationTreeSizePenaltyMultiplier;

    public Hyperparameters(Double downsampleFactor,
                           Double eta,
                           Double etaGrowthRatePerTree,
                           Double featureBagFraction,
                           Integer maxAttemptsToAddTree,
                           Integer maxOptimizationRoundsPerHyperparameter,
                           Integer maxTrees,
                           Integer numFolds,
                           Integer numSplitsPerFeature,
                           Double regularizationDepthPenaltyMultiplier,
                           Double regularizationLeafWeightPenaltyMultiplier,
                           Double regularizationSoftTreeDepthLimit,
                           Double regularizationSoftTreeDepthTolerance,
                           Double regularizationTreeSizePenaltyMultiplier) {
        this.downsampleFactor = downsampleFactor;
        this.eta = eta;
        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
        this.featureBagFraction = featureBagFraction;
        this.maxAttemptsToAddTree = maxAttemptsToAddTree;
        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
        this.maxTrees = maxTrees;
        this.numFolds = numFolds;
        this.numSplitsPerFeature = numSplitsPerFeature;
        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
    }

    public Double getDownsampleFactor() {
        return downsampleFactor;
    }

    public Double getEta() {
        return eta;
    }

    public Double getEtaGrowthRatePerTree() {
        return etaGrowthRatePerTree;
    }

    public Double getFeatureBagFraction() {
        return featureBagFraction;
    }

    public Integer getMaxAttemptsToAddTree() {
        return maxAttemptsToAddTree;
    }

    public Integer getMaxOptimizationRoundsPerHyperparameter() {
        return maxOptimizationRoundsPerHyperparameter;
    }

    public Integer getMaxTrees() {
        return maxTrees;
    }

    public Integer getNumFolds() {
        return numFolds;
    }

    public Integer getNumSplitsPerFeature() {
        return numSplitsPerFeature;
    }

    public Double getRegularizationDepthPenaltyMultiplier() {
        return regularizationDepthPenaltyMultiplier;
    }

    public Double getRegularizationLeafWeightPenaltyMultiplier() {
        return regularizationLeafWeightPenaltyMultiplier;
    }

    public Double getRegularizationSoftTreeDepthLimit() {
        return regularizationSoftTreeDepthLimit;
    }

    public Double getRegularizationSoftTreeDepthTolerance() {
        return regularizationSoftTreeDepthTolerance;
    }

    public Double getRegularizationTreeSizePenaltyMultiplier() {
        return regularizationTreeSizePenaltyMultiplier;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Only fields that were actually present are serialized back
        if (downsampleFactor != null) {
            builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
        }
        if (eta != null) {
            builder.field(ETA.getPreferredName(), eta);
        }
        if (etaGrowthRatePerTree != null) {
            builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
        }
        if (featureBagFraction != null) {
            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        }
        if (maxAttemptsToAddTree != null) {
            builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
        }
        if (maxOptimizationRoundsPerHyperparameter != null) {
            builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
        }
        if (maxTrees != null) {
            builder.field(MAX_TREES.getPreferredName(), maxTrees);
        }
        if (numFolds != null) {
            builder.field(NUM_FOLDS.getPreferredName(), numFolds);
        }
        if (numSplitsPerFeature != null) {
            builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
        }
        if (regularizationDepthPenaltyMultiplier != null) {
            builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
        }
        if (regularizationLeafWeightPenaltyMultiplier != null) {
            builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
        }
        if (regularizationSoftTreeDepthLimit != null) {
            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
        }
        if (regularizationSoftTreeDepthTolerance != null) {
            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
        }
        if (regularizationTreeSizePenaltyMultiplier != null) {
            builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Hyperparameters that = (Hyperparameters) o;
        return Objects.equals(downsampleFactor, that.downsampleFactor)
            && Objects.equals(eta, that.eta)
            && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
            && Objects.equals(featureBagFraction, that.featureBagFraction)
            && Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
            && Objects.equals(maxTrees, that.maxTrees)
            && Objects.equals(numFolds, that.numFolds)
            && Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
            && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
            && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
            && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
            && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
            && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            downsampleFactor,
            eta,
            etaGrowthRatePerTree,
            featureBagFraction,
            maxAttemptsToAddTree,
            maxOptimizationRoundsPerHyperparameter,
            maxTrees,
            numFolds,
            numSplitsPerFeature,
            regularizationDepthPenaltyMultiplier,
            regularizationLeafWeightPenaltyMultiplier,
            regularizationSoftTreeDepthLimit,
            regularizationSoftTreeDepthTolerance,
            regularizationTreeSizePenaltyMultiplier
        );
    }
}

View File

@ -0,0 +1,135 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
/**
 * Analysis instrumentation stats for a regression job: a snapshot timestamp,
 * the current iteration, the hyperparameters in effect, timing information,
 * and the validation loss. Parsed leniently so unknown fields from newer
 * servers are tolerated.
 */
public class RegressionStats implements AnalysisStats {

    public static final ParseField NAME = new ParseField("regression_stats");

    public static final ParseField TIMESTAMP = new ParseField("timestamp");
    public static final ParseField ITERATION = new ParseField("iteration");
    public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
    public static final ParseField TIMING_STATS = new ParseField("timing_stats");
    public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");

    public static final ConstructingObjectParser<RegressionStats, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
        true,
        args -> new RegressionStats(
            (Instant) args[0],
            (Integer) args[1],
            (Hyperparameters) args[2],
            (TimingStats) args[3],
            (ValidationLoss) args[4]
        )
    );

    static {
        // Timestamp may arrive as epoch millis or a formatted string; TimeUtil handles both
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
            TIMESTAMP,
            ObjectParser.ValueType.VALUE);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), ITERATION);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), Hyperparameters.PARSER, HYPERPARAMETERS);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), ValidationLoss.PARSER, VALIDATION_LOSS);
    }

    // Time the stats snapshot was taken, truncated to millisecond precision
    private final Instant timestamp;
    // Optional; null when the server did not report an iteration number
    private final Integer iteration;
    private final Hyperparameters hyperparameters;
    private final TimingStats timingStats;
    private final ValidationLoss validationLoss;

    public RegressionStats(Instant timestamp, Integer iteration, Hyperparameters hyperparameters, TimingStats timingStats,
                           ValidationLoss validationLoss) {
        // Round-trip through epoch millis drops sub-millisecond precision, matching
        // the precision used when the timestamp is serialized in toXContent
        this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
        this.iteration = iteration;
        this.hyperparameters = Objects.requireNonNull(hyperparameters);
        this.timingStats = Objects.requireNonNull(timingStats);
        this.validationLoss = Objects.requireNonNull(validationLoss);
    }

    public Instant getTimestamp() {
        return timestamp;
    }

    public Integer getIteration() {
        return iteration;
    }

    public Hyperparameters getHyperparameters() {
        return hyperparameters;
    }

    public TimingStats getTimingStats() {
        return timingStats;
    }

    public ValidationLoss getValidationLoss() {
        return validationLoss;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        // Emits both the raw millis and a human-readable "<field>_string" variant
        builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
        if (iteration != null) {
            builder.field(ITERATION.getPreferredName(), iteration);
        }
        builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
        builder.field(TIMING_STATS.getPreferredName(), timingStats);
        builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        RegressionStats other = (RegressionStats) o;
        return Objects.equals(timestamp, other.timestamp)
            && Objects.equals(iteration, other.iteration)
            && Objects.equals(hyperparameters, other.hyperparameters)
            && Objects.equals(timingStats, other.timingStats)
            && Objects.equals(validationLoss, other.validationLoss);
    }

    @Override
    public int hashCode() {
        return Objects.hash(timestamp, iteration, hyperparameters, timingStats, validationLoss);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
 * Timing information for a regression analysis: total elapsed time and the
 * duration of the latest iteration. Both values are optional.
 */
public class TimingStats implements ToXContentObject {

    public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
    public static final ParseField ITERATION_TIME = new ParseField("iteration_time");

    // final: parsers are shared, stateless singletons and must never be reassigned
    // (matches RegressionStats.PARSER in this commit)
    public static final ConstructingObjectParser<TimingStats, Void> PARSER = new ConstructingObjectParser<>("regression_timing_stats",
        true,
        a -> new TimingStats(
            a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]),
            a[1] == null ? null : TimeValue.timeValueMillis((long) a[1])
        ));

    static {
        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME);
        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ITERATION_TIME);
    }

    // Total time the analysis has been running; may be null if not reported
    private final TimeValue elapsedTime;
    // Time taken by the latest iteration; may be null if not reported
    private final TimeValue iterationTime;

    public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
        this.elapsedTime = elapsedTime;
        this.iterationTime = iterationTime;
    }

    public TimeValue getElapsedTime() {
        return elapsedTime;
    }

    public TimeValue getIterationTime() {
        return iterationTime;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Emits both raw millis and a human-readable "<field>_string" variant
        if (elapsedTime != null) {
            builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
        }
        if (iterationTime != null) {
            builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TimingStats that = (TimingStats) o;
        return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime);
    }

    @Override
    public int hashCode() {
        return Objects.hash(elapsedTime, iterationTime);
    }
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValues;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * Validation loss statistics reported for a data frame analytics regression job.
 *
 * <p>Both fields are optional; {@code null} values are omitted from the XContent output.
 */
public class ValidationLoss implements ToXContentObject {

    public static final ParseField LOSS_TYPE = new ParseField("loss_type");
    public static final ParseField FOLD_VALUES = new ParseField("fold_values");

    // final: this parser is shared, mutable-by-registration state and must never be reassigned
    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<ValidationLoss, Void> PARSER = new ConstructingObjectParser<>(
        "regression_validation_loss",
        true,
        a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));

    static {
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LOSS_TYPE);
        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), FoldValues.PARSER, FOLD_VALUES);
    }

    // The type of loss metric; may be null.
    private final String lossType;
    // Per-fold loss values; may be null.
    private final List<FoldValues> foldValues;

    /**
     * @param lossType the loss metric type, or {@code null} if not reported
     * @param values the per-fold loss values, or {@code null} if not reported
     */
    public ValidationLoss(String lossType, List<FoldValues> values) {
        this.lossType = lossType;
        this.foldValues = values;
    }

    public String getLossType() {
        return lossType;
    }

    public List<FoldValues> getFoldValues() {
        return foldValues;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (lossType != null) {
            builder.field(LOSS_TYPE.getPreferredName(), lossType);
        }
        if (foldValues != null) {
            builder.field(FOLD_VALUES.getPreferredName(), foldValues);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ValidationLoss that = (ValidationLoss) o;
        return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
    }

    @Override
    public int hashCode() {
        return Objects.hash(lossType, foldValues);
    }
}

View File

@ -1,5 +1,6 @@
org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider
org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider
org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider
org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider
org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider
org.elasticsearch.client.transform.TransformNamedXContentProvider

View File

@ -91,6 +91,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
@ -1067,6 +1068,7 @@ public class MLRequestConvertersTests extends ESTestCase {
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new AnalysisStatsNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}

View File

@ -20,7 +20,6 @@
package org.elasticsearch.client;
import com.fasterxml.jackson.core.JsonParseException;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
@ -69,6 +68,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@ -697,7 +699,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(59, namedXContents.size());
assertEquals(62, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -707,7 +709,7 @@ public class RestHighLevelClientTests extends ESTestCase {
categories.put(namedXContent.categoryClass, counter + 1);
}
}
assertEquals("Had: " + categories, 12, categories.size());
assertEquals("Had: " + categories, 13, categories.size());
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
@ -737,6 +739,9 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
assertTrue(names.contains(OutlierDetectionStats.NAME.getPreferredName()));
assertTrue(names.contains(RegressionStats.NAME.getPreferredName()));
assertTrue(names.contains(ClassificationStats.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));

View File

@ -20,6 +20,13 @@
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.client.ml.NodeAttributesTests;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests;
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests;
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.test.ESTestCase;
@ -31,23 +38,38 @@ import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
public class DataFrameAnalyticsStatsTests extends ESTestCase {
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new AnalysisStatsNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
public void testFromXContent() throws IOException {
xContentTester(this::createParser,
DataFrameAnalyticsStatsTests::randomDataFrameAnalyticsStats,
DataFrameAnalyticsStatsTests::toXContent,
DataFrameAnalyticsStats::fromXContent)
.supportsUnknownFields(true)
.randomFieldsExcludeFilter(field -> field.startsWith("node.attributes"))
.randomFieldsExcludeFilter(field -> field.startsWith("node.attributes") || field.startsWith("analysis_stats"))
.test();
}
public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() {
AnalysisStats analysisStats = randomBoolean() ? null :
randomFrom(
ClassificationStatsTests.createRandom(),
OutlierDetectionStatsTests.createRandom(),
RegressionStatsTests.createRandom()
);
return new DataFrameAnalyticsStats(
randomAlphaOfLengthBetween(1, 10),
randomFrom(DataFrameAnalyticsState.values()),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : createRandomProgress(),
randomBoolean() ? null : MemoryUsageTests.createRandom(),
analysisStats,
randomBoolean() ? null : NodeAttributesTests.createRandom(),
randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20));
}
@ -74,6 +96,11 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase {
if (stats.getMemoryUsage() != null) {
builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage());
}
if (stats.getAnalysisStats() != null) {
builder.startObject("analysis_stats");
builder.field(stats.getAnalysisStats().getName(), stats.getAnalysisStats());
builder.endObject();
}
if (stats.getNode() != null) {
builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode());
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.time.Instant;
public class ClassificationStatsTests extends AbstractXContentTestCase<ClassificationStats> {

    /** Builds a randomized {@link ClassificationStats} instance for XContent round-trip testing. */
    public static ClassificationStats createRandom() {
        // Draw order is fixed: timestamp, optional iteration, then the nested stats objects.
        Instant timestamp = Instant.now();
        Integer iteration = randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE);
        return new ClassificationStats(
            timestamp,
            iteration,
            HyperparametersTests.createRandom(),
            TimingStatsTests.createRandom(),
            ValidationLossTests.createRandom()
        );
    }

    @Override
    protected ClassificationStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected ClassificationStats doParseInstance(XContentParser parser) throws IOException {
        return ClassificationStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // The parser is lenient, so unknown fields must round-trip safely.
        return true;
    }
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class HyperparametersTests extends AbstractXContentTestCase<Hyperparameters> {

    /** Either {@code null} or a random double, with equal probability. */
    private static Double optionalDouble() {
        return randomBoolean() ? null : randomDouble();
    }

    /** Either {@code null} or a random non-negative int, with equal probability. */
    private static Integer optionalInt() {
        return randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE);
    }

    /** Builds a fully randomized {@link Hyperparameters} instance (every field may be null). */
    public static Hyperparameters createRandom() {
        return new Hyperparameters(
            randomBoolean() ? null : randomAlphaOfLength(10),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalInt(),
            optionalInt(),
            optionalInt(),
            optionalInt(),
            optionalInt(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble()
        );
    }

    @Override
    protected Hyperparameters createTestInstance() {
        return createRandom();
    }

    @Override
    protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
        return Hyperparameters.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class TimingStatsTests extends AbstractXContentTestCase<TimingStats> {

    /** Builds a randomized {@link TimingStats}; each time value is independently nullable. */
    public static TimingStats createRandom() {
        TimeValue first = randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong());
        TimeValue second = randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong());
        return new TimingStats(first, second);
    }

    @Override
    protected TimingStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected TimingStats doParseInstance(XContentParser parser) throws IOException {
        return TimingStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.classification;
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValuesTests;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class ValidationLossTests extends AbstractXContentTestCase<ValidationLoss> {

    /** Builds a randomized {@link ValidationLoss}; both fields are independently nullable. */
    public static ValidationLoss createRandom() {
        String lossType = randomBoolean() ? null : randomAlphaOfLength(10);
        List<FoldValues> folds = randomBoolean() ? null : randomList(5, FoldValuesTests::createRandom);
        return new ValidationLoss(lossType, folds);
    }

    @Override
    protected ValidationLoss createTestInstance() {
        return createRandom();
    }

    @Override
    protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
        return ValidationLoss.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.common;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class FoldValuesTests extends AbstractXContentTestCase<FoldValues> {

    /** Builds a randomized {@link FoldValues} with 0-10 random values. */
    public static FoldValues createRandom() {
        // Draw order: array size first, then each element, then the fold number.
        int size = randomIntBetween(0, 10);
        double[] values = new double[size];
        int i = 0;
        while (i < size) {
            values[i++] = randomDouble();
        }
        return new FoldValues(randomIntBetween(0, Integer.MAX_VALUE), values);
    }

    @Override
    protected FoldValues createTestInstance() {
        return createRandom();
    }

    @Override
    protected FoldValues doParseInstance(XContentParser parser) throws IOException {
        return FoldValues.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
package org.elasticsearch.client.ml.dataframe.stats.common;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

View File

@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.time.Instant;
public class OutlierDetectionStatsTests extends AbstractXContentTestCase<OutlierDetectionStats> {

    /** Builds a randomized {@link OutlierDetectionStats} instance for XContent round-trip testing. */
    public static OutlierDetectionStats createRandom() {
        Instant timestamp = Instant.now();
        return new OutlierDetectionStats(
            timestamp,
            ParametersTests.createRandom(),
            TimingStatsTests.createRandom()
        );
    }

    @Override
    protected OutlierDetectionStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected OutlierDetectionStats doParseInstance(XContentParser parser) throws IOException {
        return OutlierDetectionStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class ParametersTests extends AbstractXContentTestCase<Parameters> {

    /** Either {@code null} or a random boolean, with equal probability. */
    private static Boolean optionalBoolean() {
        return randomBoolean() ? null : randomBoolean();
    }

    /** Either {@code null} or a random double, with equal probability. */
    private static Double optionalDouble() {
        return randomBoolean() ? null : randomDouble();
    }

    /** Builds a fully randomized {@link Parameters} instance (every field may be null). */
    public static Parameters createRandom() {
        return new Parameters(
            randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE),
            randomBoolean() ? null : randomAlphaOfLength(5),
            optionalBoolean(),
            optionalDouble(),
            optionalDouble(),
            optionalBoolean()
        );
    }

    @Override
    protected Parameters createTestInstance() {
        return createRandom();
    }

    @Override
    protected Parameters doParseInstance(XContentParser parser) throws IOException {
        return Parameters.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,48 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class TimingStatsTests extends AbstractXContentTestCase<TimingStats> {

    /** Builds a randomized {@link TimingStats}; the time value is nullable. */
    public static TimingStats createRandom() {
        TimeValue value = randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong());
        return new TimingStats(value);
    }

    @Override
    protected TimingStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected TimingStats doParseInstance(XContentParser parser) throws IOException {
        return TimingStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class HyperparametersTests extends AbstractXContentTestCase<Hyperparameters> {

    /** Either {@code null} or a random double, with equal probability. */
    private static Double optionalDouble() {
        return randomBoolean() ? null : randomDouble();
    }

    /** Either {@code null} or a random non-negative int, with equal probability. */
    private static Integer optionalInt() {
        return randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE);
    }

    /** Builds a fully randomized {@link Hyperparameters} instance (every field may be null). */
    public static Hyperparameters createRandom() {
        return new Hyperparameters(
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalInt(),
            optionalInt(),
            optionalInt(),
            optionalInt(),
            optionalInt(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble(),
            optionalDouble()
        );
    }

    @Override
    protected Hyperparameters createTestInstance() {
        return createRandom();
    }

    @Override
    protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
        return Hyperparameters.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,54 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.time.Instant;
public class RegressionStatsTests extends AbstractXContentTestCase<RegressionStats> {

    /** Builds a randomized {@link RegressionStats} instance for XContent round-trip testing. */
    public static RegressionStats createRandom() {
        // Draw order is fixed: timestamp, optional iteration, then the nested stats objects.
        Instant timestamp = Instant.now();
        Integer iteration = randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE);
        return new RegressionStats(
            timestamp,
            iteration,
            HyperparametersTests.createRandom(),
            TimingStatsTests.createRandom(),
            ValidationLossTests.createRandom()
        );
    }

    @Override
    protected RegressionStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected RegressionStats doParseInstance(XContentParser parser) throws IOException {
        return RegressionStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class TimingStatsTests extends AbstractXContentTestCase<TimingStats> {

    /** Builds a randomized {@link TimingStats}; each time value is independently nullable. */
    public static TimingStats createRandom() {
        // Constructor order is (elapsedTime, iterationTime) — see the production TimingStats class.
        TimeValue elapsed = randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong());
        TimeValue iteration = randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong());
        return new TimingStats(elapsed, iteration);
    }

    @Override
    protected TimingStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected TimingStats doParseInstance(XContentParser parser) throws IOException {
        return TimingStats.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.stats.regression;
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValuesTests;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class ValidationLossTests extends AbstractXContentTestCase<ValidationLoss> {

    /** Builds a randomized {@link ValidationLoss}; both fields are independently nullable. */
    public static ValidationLoss createRandom() {
        String lossType = randomBoolean() ? null : randomAlphaOfLength(10);
        List<FoldValues> folds = randomBoolean() ? null : randomList(5, FoldValuesTests::createRandom);
        return new ValidationLoss(lossType, folds);
    }

    @Override
    protected ValidationLoss createTestInstance() {
        return createRandom();
    }

    @Override
    protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
        return ValidationLoss.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Lenient parser: unknown fields in the payload are tolerated.
        return true;
    }
}

View File

@ -39,18 +39,15 @@ import org.elasticsearch.transport.Transport;
import org.elasticsearch.xpack.core.action.XPackInfoAction;
import org.elasticsearch.xpack.core.action.XPackUsageAction;
import org.elasticsearch.xpack.core.analytics.AnalyticsFeatureSetUsage;
import org.elasticsearch.xpack.core.search.action.DeleteAsyncSearchAction;
import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction;
import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
import org.elasticsearch.xpack.core.ccr.AutoFollowMetadata;
import org.elasticsearch.xpack.core.ccr.CCRFeatureSet;
import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction;
import org.elasticsearch.xpack.core.enrich.EnrichFeatureSet;
import org.elasticsearch.xpack.core.eql.EqlFeatureSetUsage;
import org.elasticsearch.xpack.core.enrich.action.DeleteEnrichPolicyAction;
import org.elasticsearch.xpack.core.enrich.action.ExecuteEnrichPolicyAction;
import org.elasticsearch.xpack.core.enrich.action.GetEnrichPolicyAction;
import org.elasticsearch.xpack.core.enrich.action.PutEnrichPolicyAction;
import org.elasticsearch.xpack.core.eql.EqlFeatureSetUsage;
import org.elasticsearch.xpack.core.flattened.FlattenedFeatureSetUsage;
import org.elasticsearch.xpack.core.frozen.FrozenIndicesFeatureSetUsage;
import org.elasticsearch.xpack.core.frozen.action.FreezeIndexAction;
@ -152,6 +149,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
@ -184,6 +185,9 @@ import org.elasticsearch.xpack.core.rollup.action.StartRollupJobAction;
import org.elasticsearch.xpack.core.rollup.action.StopRollupJobAction;
import org.elasticsearch.xpack.core.rollup.job.RollupJob;
import org.elasticsearch.xpack.core.rollup.job.RollupJobStatus;
import org.elasticsearch.xpack.core.search.action.DeleteAsyncSearchAction;
import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction;
import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
import org.elasticsearch.xpack.core.security.SecurityFeatureSetUsage;
import org.elasticsearch.xpack.core.security.SecurityField;
import org.elasticsearch.xpack.core.security.SecuritySettings;
@ -505,6 +509,9 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
new NamedWriteableRegistry.Entry(AnalysisStats.class, OutlierDetectionStats.TYPE_VALUE, OutlierDetectionStats::new),
new NamedWriteableRegistry.Entry(AnalysisStats.class, RegressionStats.TYPE_VALUE, RegressionStats::new),
new NamedWriteableRegistry.Entry(AnalysisStats.class, ClassificationStats.TYPE_VALUE, ClassificationStats::new),
// ML - Inference preprocessing
new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),

View File

@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
@ -167,18 +168,23 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
@Nullable
private final MemoryUsage memoryUsage;
@Nullable
private final AnalysisStats analysisStats;
@Nullable
private final DiscoveryNode node;
@Nullable
private final String assignmentExplanation;
public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List<PhaseProgress> progress,
@Nullable MemoryUsage memoryUsage, @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) {
@Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node,
@Nullable String assignmentExplanation) {
this.id = Objects.requireNonNull(id);
this.state = Objects.requireNonNull(state);
this.failureReason = failureReason;
this.progress = Objects.requireNonNull(progress);
this.memoryUsage = memoryUsage;
this.analysisStats = analysisStats;
this.node = node;
this.assignmentExplanation = assignmentExplanation;
}
@ -197,6 +203,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
} else {
memoryUsage = null;
}
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
analysisStats = in.readOptionalNamedWriteable(AnalysisStats.class);
} else {
analysisStats = null;
}
node = in.readOptionalWriteable(DiscoveryNode::new);
assignmentExplanation = in.readOptionalString();
}
@ -285,6 +296,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
if (memoryUsage != null) {
builder.field("memory_usage", memoryUsage);
}
if (analysisStats != null) {
builder.startObject("analysis_stats");
builder.field(analysisStats.getWriteableName(), analysisStats);
builder.endObject();
}
if (node != null) {
builder.startObject("node");
builder.field("id", node.getId());
@ -318,6 +334,9 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeOptionalWriteable(memoryUsage);
}
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeOptionalNamedWriteable(analysisStats);
}
out.writeOptionalWriteable(node);
out.writeOptionalString(assignmentExplanation);
}
@ -350,7 +369,7 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
@Override
public int hashCode() {
return Objects.hash(id, state, failureReason, progress, memoryUsage, node, assignmentExplanation);
return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
}
@Override
@ -367,6 +386,7 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
&& Objects.equals(this.failureReason, other.failureReason)
&& Objects.equals(this.progress, other.progress)
&& Objects.equals(this.memoryUsage, other.memoryUsage)
&& Objects.equals(this.analysisStats, other.analysisStats)
&& Objects.equals(this.node, other.node)
&& Objects.equals(this.assignmentExplanation, other.assignmentExplanation);
}

View File

@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
/**
 * Statistics for the data frame analysis.
 *
 * Implementations are {@link NamedWriteable} so the concrete stats type
 * (classification, outlier detection, regression) can be resolved when read
 * from the transport layer, and {@link ToXContentObject} so they can be
 * rendered in responses and stored documents.
 */
public interface AnalysisStats extends ToXContentObject, NamedWriteable {
}

View File

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import java.util.Arrays;
import java.util.List;
/**
 * Supplies the {@link NamedWriteableRegistry.Entry} instances required to
 * deserialize the concrete {@link AnalysisStats} implementations from the
 * transport layer.
 */
public class AnalysisStatsNamedWriteablesProvider {

    /**
     * @return one registry entry per concrete {@link AnalysisStats} type,
     *         keyed by that type's {@code TYPE_VALUE} writeable name
     */
    public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
        return Arrays.asList(
            new NamedWriteableRegistry.Entry(AnalysisStats.class, ClassificationStats.TYPE_VALUE, ClassificationStats::new),
            new NamedWriteableRegistry.Entry(AnalysisStats.class, OutlierDetectionStats.TYPE_VALUE, OutlierDetectionStats::new),
            new NamedWriteableRegistry.Entry(AnalysisStats.class, RegressionStats.TYPE_VALUE, RegressionStats::new)
        );
    }
}

View File

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats;
import org.elasticsearch.common.ParseField;
/**
 * A collection of parse fields commonly used by stats objects.
 */
public final class Fields {

    // Type discriminator emitted only for internal storage, so a stored stats
    // document records which concrete stats class it holds.
    public static final ParseField TYPE = new ParseField("type");

    // Identifier of the data frame analytics job the stats belong to.
    public static final ParseField JOB_ID = new ParseField("job_id");

    // Time the stats snapshot was taken; stats classes round it to millisecond
    // granularity on construction.
    public static final ParseField TIMESTAMP = new ParseField("timestamp");

    private Fields() {}
}

View File

@ -26,9 +26,6 @@ public class MemoryUsage implements Writeable, ToXContentObject {
public static final String TYPE_VALUE = "analytics_memory_usage";
public static final ParseField TYPE = new ParseField("type");
public static final ParseField JOB_ID = new ParseField("job_id");
public static final ParseField TIMESTAMP = new ParseField("timestamp");
public static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes");
public static final ConstructingObjectParser<MemoryUsage, Void> STRICT_PARSER = createParser(false);
@ -38,11 +35,11 @@ public class MemoryUsage implements Writeable, ToXContentObject {
ConstructingObjectParser<MemoryUsage, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE,
ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2]));
parser.declareString((bucket, s) -> {}, TYPE);
parser.declareString(ConstructingObjectParser.constructorArg(), JOB_ID);
parser.declareString((bucket, s) -> {}, Fields.TYPE);
parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
parser.declareField(ConstructingObjectParser.constructorArg(),
p -> TimeUtils.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
TIMESTAMP,
p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
Fields.TIMESTAMP,
ObjectParser.ValueType.VALUE);
parser.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES);
return parser;
@ -56,7 +53,7 @@ public class MemoryUsage implements Writeable, ToXContentObject {
this.jobId = Objects.requireNonNull(jobId);
// We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
// internal representation matches toXContent
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, TIMESTAMP).toEpochMilli());
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
this.peakUsageBytes = peakUsageBytes;
}
@ -77,10 +74,10 @@ public class MemoryUsage implements Writeable, ToXContentObject {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(TYPE.getPreferredName(), TYPE_VALUE);
builder.field(JOB_ID.getPreferredName(), jobId);
builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
builder.field(Fields.JOB_ID.getPreferredName(), jobId);
}
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes);
builder.endObject();
return builder;

View File

@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
/**
 * Instrumentation statistics for a classification analysis: the job id and
 * snapshot timestamp together with the current iteration, the hyperparameters
 * in use, timing information, and the validation loss.
 */
public class ClassificationStats implements AnalysisStats {

    public static final String TYPE_VALUE = "classification_stats";

    public static final ParseField ITERATION = new ParseField("iteration");
    public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
    public static final ParseField TIMING_STATS = new ParseField("timing_stats");
    public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");

    // Strict parsing rejects unknown fields; lenient parsing ignores them
    // (used when reading documents written by other, possibly newer, versions).
    public static final ConstructingObjectParser<ClassificationStats, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<ClassificationStats, Void> LENIENT_PARSER = createParser(true);

    private static ConstructingObjectParser<ClassificationStats, Void> createParser(boolean ignoreUnknownFields) {
        // Constructor argument indices a[0]..a[5] follow the declaration order below.
        ConstructingObjectParser<ClassificationStats, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
            a -> new ClassificationStats(
                (String) a[0],
                (Instant) a[1],
                (int) a[2],
                (Hyperparameters) a[3],
                (TimingStats) a[4],
                (ValidationLoss) a[5]
            )
        );
        // The type field is accepted but discarded; it only discriminates stored documents.
        parser.declareString((bucket, s) -> {}, Fields.TYPE);
        parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
        parser.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
            Fields.TIMESTAMP,
            ObjectParser.ValueType.VALUE);
        parser.declareInt(ConstructingObjectParser.constructorArg(), ITERATION);
        // Nested objects inherit the same leniency as this parser.
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> Hyperparameters.fromXContent(p, ignoreUnknownFields), HYPERPARAMETERS);
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS);
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> ValidationLoss.fromXContent(p, ignoreUnknownFields), VALIDATION_LOSS);
        return parser;
    }

    private final String jobId;
    private final Instant timestamp;
    private final int iteration;
    private final Hyperparameters hyperparameters;
    private final TimingStats timingStats;
    private final ValidationLoss validationLoss;

    public ClassificationStats(String jobId, Instant timestamp, int iteration, Hyperparameters hyperparameters, TimingStats timingStats,
                               ValidationLoss validationLoss) {
        this.jobId = Objects.requireNonNull(jobId);
        // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
        // internal representation matches toXContent
        this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
        this.iteration = iteration;
        this.hyperparameters = Objects.requireNonNull(hyperparameters);
        this.timingStats = Objects.requireNonNull(timingStats);
        this.validationLoss = Objects.requireNonNull(validationLoss);
    }

    // Deserializes fields in exactly the order written by writeTo.
    public ClassificationStats(StreamInput in) throws IOException {
        this.jobId = in.readString();
        this.timestamp = in.readInstant();
        this.iteration = in.readVInt();
        this.hyperparameters = new Hyperparameters(in);
        this.timingStats = new TimingStats(in);
        this.validationLoss = new ValidationLoss(in);
    }

    @Override
    public String getWriteableName() {
        return TYPE_VALUE;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Wire order must stay in sync with the StreamInput constructor.
        out.writeString(jobId);
        out.writeInstant(timestamp);
        out.writeVInt(iteration);
        hyperparameters.writeTo(out);
        timingStats.writeTo(out);
        validationLoss.writeTo(out);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // type and job_id are only persisted when writing to the internal index,
        // where they identify the document; REST responses omit them.
        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
            builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
            builder.field(Fields.JOB_ID.getPreferredName(), jobId);
        }
        builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
        builder.field(ITERATION.getPreferredName(), iteration);
        builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
        builder.field(TIMING_STATS.getPreferredName(), timingStats);
        builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ClassificationStats that = (ClassificationStats) o;
        return Objects.equals(jobId, that.jobId)
            && Objects.equals(timestamp, that.timestamp)
            && iteration == that.iteration
            && Objects.equals(hyperparameters, that.hyperparameters)
            && Objects.equals(timingStats, that.timingStats)
            && Objects.equals(validationLoss, that.validationLoss);
    }

    @Override
    public int hashCode() {
        return Objects.hash(jobId, timestamp, iteration, hyperparameters, timingStats, validationLoss);
    }

    // NOTE(review): uses the jobId argument (shadowing the field) so callers can
    // build ids without an instance's own jobId — confirm this is intentional.
    public String documentId(String jobId) {
        return documentIdPrefix(jobId) + timestamp.toEpochMilli();
    }

    /** Prefix shared by all stats documents of the given job, used for lookups and deletes. */
    public static String documentIdPrefix(String jobId) {
        return TYPE_VALUE + "_" + jobId + "_";
    }
}

View File

@ -0,0 +1,237 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
 * The hyperparameters used by a classification analysis, as reported by the
 * ML process. Immutable value object; wire order in {@link #writeTo} must
 * match the {@link StreamInput} constructor exactly.
 */
public class Hyperparameters implements ToXContentObject, Writeable {

    public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
    public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
    public static final ParseField ETA = new ParseField("eta");
    public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
    public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
    public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
        "max_optimization_rounds_per_hyperparameter");
    public static final ParseField MAX_TREES = new ParseField("max_trees");
    public static final ParseField NUM_FOLDS = new ParseField("num_folds");
    public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
        = new ParseField("regularization_leaf_weight_penalty_multiplier");
    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
        new ParseField("regularization_tree_size_penalty_multiplier");

    public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<Hyperparameters, Void> createParser(boolean ignoreUnknownFields) {
        // Constructor argument indices a[0]..a[14] follow the declaration order below.
        ConstructingObjectParser<Hyperparameters, Void> parser = new ConstructingObjectParser<>("classification_hyperparameters",
            ignoreUnknownFields,
            a -> new Hyperparameters(
                (String) a[0],
                (double) a[1],
                (double) a[2],
                (double) a[3],
                (double) a[4],
                (int) a[5],
                (int) a[6],
                (int) a[7],
                (int) a[8],
                (int) a[9],
                (double) a[10],
                (double) a[11],
                (double) a[12],
                (double) a[13],
                (double) a[14]
            ));
        parser.declareString(constructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
        parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
        parser.declareDouble(constructorArg(), ETA);
        parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
        parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
        parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
        parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
        parser.declareInt(constructorArg(), MAX_TREES);
        parser.declareInt(constructorArg(), NUM_FOLDS);
        parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
        parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
        parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
        parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
        parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
        parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
        return parser;
    }

    private final String classAssignmentObjective;
    private final double downsampleFactor;
    private final double eta;
    private final double etaGrowthRatePerTree;
    private final double featureBagFraction;
    private final int maxAttemptsToAddTree;
    private final int maxOptimizationRoundsPerHyperparameter;
    private final int maxTrees;
    private final int numFolds;
    private final int numSplitsPerFeature;
    private final double regularizationDepthPenaltyMultiplier;
    private final double regularizationLeafWeightPenaltyMultiplier;
    private final double regularizationSoftTreeDepthLimit;
    private final double regularizationSoftTreeDepthTolerance;
    private final double regularizationTreeSizePenaltyMultiplier;

    public Hyperparameters(String classAssignmentObjective,
                           double downsampleFactor,
                           double eta,
                           double etaGrowthRatePerTree,
                           double featureBagFraction,
                           int maxAttemptsToAddTree,
                           int maxOptimizationRoundsPerHyperparameter,
                           int maxTrees,
                           int numFolds,
                           int numSplitsPerFeature,
                           double regularizationDepthPenaltyMultiplier,
                           double regularizationLeafWeightPenaltyMultiplier,
                           double regularizationSoftTreeDepthLimit,
                           double regularizationSoftTreeDepthTolerance,
                           double regularizationTreeSizePenaltyMultiplier) {
        this.classAssignmentObjective = Objects.requireNonNull(classAssignmentObjective);
        this.downsampleFactor = downsampleFactor;
        this.eta = eta;
        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
        this.featureBagFraction = featureBagFraction;
        this.maxAttemptsToAddTree = maxAttemptsToAddTree;
        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
        this.maxTrees = maxTrees;
        this.numFolds = numFolds;
        this.numSplitsPerFeature = numSplitsPerFeature;
        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
    }

    // Deserializes fields in exactly the order written by writeTo.
    public Hyperparameters(StreamInput in) throws IOException {
        this.classAssignmentObjective = in.readString();
        this.downsampleFactor = in.readDouble();
        this.eta = in.readDouble();
        this.etaGrowthRatePerTree = in.readDouble();
        this.featureBagFraction = in.readDouble();
        this.maxAttemptsToAddTree = in.readVInt();
        this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
        this.maxTrees = in.readVInt();
        this.numFolds = in.readVInt();
        this.numSplitsPerFeature = in.readVInt();
        this.regularizationDepthPenaltyMultiplier = in.readDouble();
        this.regularizationLeafWeightPenaltyMultiplier = in.readDouble();
        this.regularizationSoftTreeDepthLimit = in.readDouble();
        this.regularizationSoftTreeDepthTolerance = in.readDouble();
        this.regularizationTreeSizePenaltyMultiplier = in.readDouble();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(classAssignmentObjective);
        out.writeDouble(downsampleFactor);
        out.writeDouble(eta);
        out.writeDouble(etaGrowthRatePerTree);
        out.writeDouble(featureBagFraction);
        out.writeVInt(maxAttemptsToAddTree);
        out.writeVInt(maxOptimizationRoundsPerHyperparameter);
        out.writeVInt(maxTrees);
        out.writeVInt(numFolds);
        out.writeVInt(numSplitsPerFeature);
        out.writeDouble(regularizationDepthPenaltyMultiplier);
        out.writeDouble(regularizationLeafWeightPenaltyMultiplier);
        out.writeDouble(regularizationSoftTreeDepthLimit);
        out.writeDouble(regularizationSoftTreeDepthTolerance);
        out.writeDouble(regularizationTreeSizePenaltyMultiplier);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
        builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
        builder.field(ETA.getPreferredName(), eta);
        builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
        builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
        builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
        builder.field(MAX_TREES.getPreferredName(), maxTrees);
        builder.field(NUM_FOLDS.getPreferredName(), numFolds);
        builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
        builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
        builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
        builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
        builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
        builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Hyperparameters that = (Hyperparameters) o;
        // Compare doubles via Double.compare rather than `==` so that equals is
        // reflexive when a field holds NaN (NaN == NaN is false) and stays
        // consistent with hashCode, which boxes the values via Objects.hash
        // (with `==`, -0.0 and 0.0 would be "equal" yet hash differently).
        return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
            && Double.compare(downsampleFactor, that.downsampleFactor) == 0
            && Double.compare(eta, that.eta) == 0
            && Double.compare(etaGrowthRatePerTree, that.etaGrowthRatePerTree) == 0
            && Double.compare(featureBagFraction, that.featureBagFraction) == 0
            && maxAttemptsToAddTree == that.maxAttemptsToAddTree
            && maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
            && maxTrees == that.maxTrees
            && numFolds == that.numFolds
            && numSplitsPerFeature == that.numSplitsPerFeature
            && Double.compare(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier) == 0
            && Double.compare(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier) == 0
            && Double.compare(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit) == 0
            && Double.compare(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance) == 0
            && Double.compare(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier) == 0;
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            classAssignmentObjective,
            downsampleFactor,
            eta,
            etaGrowthRatePerTree,
            featureBagFraction,
            maxAttemptsToAddTree,
            maxOptimizationRoundsPerHyperparameter,
            maxTrees,
            numFolds,
            numSplitsPerFeature,
            regularizationDepthPenaltyMultiplier,
            regularizationLeafWeightPenaltyMultiplier,
            regularizationSoftTreeDepthLimit,
            regularizationSoftTreeDepthTolerance,
            regularizationTreeSizePenaltyMultiplier
        );
    }
}

View File

@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
/**
 * Timing statistics for a classification analysis: the total elapsed time and
 * the time taken by the latest iteration. Immutable value object.
 */
public class TimingStats implements Writeable, ToXContentObject {

    public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
    public static final ParseField ITERATION_TIME = new ParseField("iteration_time");

    public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<TimingStats, Void> createParser(boolean ignoreUnknownFields) {
        // Both fields are parsed as raw millisecond longs and wrapped into TimeValues.
        ConstructingObjectParser<TimingStats, Void> objectParser = new ConstructingObjectParser<>(
            "classification_timing_stats",
            ignoreUnknownFields,
            args -> new TimingStats(TimeValue.timeValueMillis((long) args[0]), TimeValue.timeValueMillis((long) args[1])));
        objectParser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME);
        objectParser.declareLong(ConstructingObjectParser.constructorArg(), ITERATION_TIME);
        return objectParser;
    }

    // Total time spent on the analysis so far.
    private final TimeValue elapsedTime;
    // Time taken by the most recent iteration.
    private final TimeValue iterationTime;

    public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
        this.elapsedTime = Objects.requireNonNull(elapsedTime);
        this.iterationTime = Objects.requireNonNull(iterationTime);
    }

    // Deserializes fields in exactly the order written by writeTo.
    public TimingStats(StreamInput in) throws IOException {
        this.elapsedTime = in.readTimeValue();
        this.iterationTime = in.readTimeValue();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeTimeValue(elapsedTime);
        out.writeTimeValue(iterationTime);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        // humanReadableField emits millis under the raw name plus a "_string" rendering.
        builder.startObject();
        builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
        builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object other) {
        if (other == this) {
            return true;
        }
        if (other == null || other.getClass() != getClass()) {
            return false;
        }
        TimingStats rhs = (TimingStats) other;
        return Objects.equals(elapsedTime, rhs.elapsedTime)
            && Objects.equals(iterationTime, rhs.iterationTime);
    }

    @Override
    public int hashCode() {
        return Objects.hash(elapsedTime, iterationTime);
    }
}

View File

@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValues;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * The validation loss results of a classification analysis: the type of loss
 * that was computed plus the loss values for each cross-validation fold.
 */
public class ValidationLoss implements ToXContentObject, Writeable {

    public static final ParseField LOSS_TYPE = new ParseField("loss_type");
    public static final ParseField FOLD_VALUES = new ParseField("fold_values");

    public static ValidationLoss fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    // The cast of a[1] to List<FoldValues> is unavoidable with ConstructingObjectParser's
    // untyped argument array; suppress the resulting unchecked warning explicitly.
    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<ValidationLoss, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<ValidationLoss, Void> parser = new ConstructingObjectParser<>("classification_validation_loss",
            ignoreUnknownFields,
            a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));
        parser.declareString(ConstructingObjectParser.constructorArg(), LOSS_TYPE);
        parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
            (p, c) -> FoldValues.fromXContent(p, ignoreUnknownFields), FOLD_VALUES);
        return parser;
    }

    // The type of loss that was computed, e.g. a logistic loss.
    private final String lossType;
    // Loss values per cross-validation fold.
    private final List<FoldValues> foldValues;

    public ValidationLoss(String lossType, List<FoldValues> values) {
        this.lossType = Objects.requireNonNull(lossType);
        this.foldValues = Objects.requireNonNull(values);
    }

    public ValidationLoss(StreamInput in) throws IOException {
        lossType = in.readString();
        foldValues = in.readList(FoldValues::new);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Wire order must match the StreamInput constructor above.
        out.writeString(lossType);
        out.writeList(foldValues);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(LOSS_TYPE.getPreferredName(), lossType);
        builder.field(FOLD_VALUES.getPreferredName(), foldValues);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ValidationLoss that = (ValidationLoss) o;
        return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
    }

    @Override
    public int hashCode() {
        return Objects.hash(lossType, foldValues);
    }
}

View File

@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
 * A set of values associated with a single cross-validation fold,
 * e.g. the validation loss values computed for that fold.
 */
public class FoldValues implements Writeable, ToXContentObject {

    public static final ParseField FOLD = new ParseField("fold");
    public static final ParseField VALUES = new ParseField("values");

    public static FoldValues fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<FoldValues, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<FoldValues, Void> objectParser = new ConstructingObjectParser<>("fold_values", ignoreUnknownFields,
            args -> new FoldValues((int) args[0], (List<Double>) args[1]));
        objectParser.declareInt(ConstructingObjectParser.constructorArg(), FOLD);
        objectParser.declareDoubleArray(ConstructingObjectParser.constructorArg(), VALUES);
        return objectParser;
    }

    // Index of the cross-validation fold these values belong to.
    private final int fold;
    private final double[] values;

    // Parser-only constructor: unboxes the parsed list into a primitive array.
    private FoldValues(int fold, List<Double> values) {
        this.fold = fold;
        double[] unboxed = new double[values.size()];
        for (int i = 0; i < unboxed.length; i++) {
            unboxed[i] = values.get(i);
        }
        this.values = unboxed;
    }

    public FoldValues(int fold, double[] values) {
        this.fold = fold;
        this.values = values;
    }

    public FoldValues(StreamInput in) throws IOException {
        fold = in.readVInt();
        values = in.readDoubleArray();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Wire order must match the StreamInput constructor above.
        out.writeVInt(fold);
        out.writeDoubleArray(values);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FOLD.getPreferredName(), fold);
        builder.array(VALUES.getPreferredName(), values);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        FoldValues other = (FoldValues) o;
        return fold == other.fold && Arrays.equals(values, other.values);
    }

    @Override
    public int hashCode() {
        // Use the array's content hash so hashCode stays consistent with equals.
        return Objects.hash(fold, Arrays.hashCode(values));
    }
}

View File

@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
/**
 * Instrumentation stats for an outlier detection analysis. Captures, at a point
 * in time for a given job, the parameters the analysis runs with and its timing stats.
 */
public class OutlierDetectionStats implements AnalysisStats {

    public static final String TYPE_VALUE = "outlier_detection_stats";

    public static final ParseField PARAMETERS = new ParseField("parameters");
    // "timing_stats" for consistency with the other analysis stats types
    // (e.g. RegressionStats); the previous name "timings_stats" was a typo.
    public static final ParseField TIMING_STATS = new ParseField("timing_stats");

    public static final ConstructingObjectParser<OutlierDetectionStats, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<OutlierDetectionStats, Void> LENIENT_PARSER = createParser(true);

    private static ConstructingObjectParser<OutlierDetectionStats, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<OutlierDetectionStats, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
            a -> new OutlierDetectionStats((String) a[0], (Instant) a[1], (Parameters) a[2], (TimingStats) a[3]));
        // The type field only appears in internally-stored documents; it is read and discarded.
        parser.declareString((bucket, s) -> {}, Fields.TYPE);
        parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
        parser.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
            Fields.TIMESTAMP,
            ObjectParser.ValueType.VALUE);
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> Parameters.fromXContent(p, ignoreUnknownFields), PARAMETERS);
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS);
        return parser;
    }

    private final String jobId;
    private final Instant timestamp;
    private final Parameters parameters;
    private final TimingStats timingStats;

    public OutlierDetectionStats(String jobId, Instant timestamp, Parameters parameters, TimingStats timingStats) {
        this.jobId = Objects.requireNonNull(jobId);
        // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
        // internal representation matches toXContent
        this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
        this.parameters = Objects.requireNonNull(parameters);
        this.timingStats = Objects.requireNonNull(timingStats);
    }

    public OutlierDetectionStats(StreamInput in) throws IOException {
        this.jobId = in.readString();
        this.timestamp = in.readInstant();
        this.parameters = new Parameters(in);
        this.timingStats = new TimingStats(in);
    }

    @Override
    public String getWriteableName() {
        return TYPE_VALUE;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Wire order must match the StreamInput constructor above.
        out.writeString(jobId);
        out.writeInstant(timestamp);
        parameters.writeTo(out);
        timingStats.writeTo(out);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Type and job id are only written for internal storage, not for API responses.
        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
            builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
            builder.field(Fields.JOB_ID.getPreferredName(), jobId);
        }
        builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
        builder.field(PARAMETERS.getPreferredName(), parameters);
        builder.field(TIMING_STATS.getPreferredName(), timingStats);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OutlierDetectionStats that = (OutlierDetectionStats) o;
        return Objects.equals(jobId, that.jobId)
            && Objects.equals(timestamp, that.timestamp)
            && Objects.equals(parameters, that.parameters)
            && Objects.equals(timingStats, that.timingStats);
    }

    @Override
    public int hashCode() {
        return Objects.hash(jobId, timestamp, parameters, timingStats);
    }

    /** Returns the id of the document this object is stored as for the given job. */
    public String documentId(String jobId) {
        return documentIdPrefix(jobId) + timestamp.toEpochMilli();
    }

    /** Returns the document id prefix shared by all stats documents of the given job. */
    public static String documentIdPrefix(String jobId) {
        return TYPE_VALUE + "_" + jobId + "_";
    }
}

View File

@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
 * The parameters an outlier detection analysis was run with.
 */
public class Parameters implements Writeable, ToXContentObject {

    public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors");
    public static final ParseField METHOD = new ParseField("method");
    public static final ParseField FEATURE_INFLUENCE_THRESHOLD = new ParseField("feature_influence_threshold");
    public static final ParseField COMPUTE_FEATURE_INFLUENCE = new ParseField("compute_feature_influence");
    public static final ParseField OUTLIER_FRACTION = new ParseField("outlier_fraction");
    public static final ParseField STANDARDIZATION_ENABLED = new ParseField("standardization_enabled");

    public static Parameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<Parameters, Void> createParser(boolean ignoreUnknownFields) {
        // Declaration order below determines the index of each value in args.
        ConstructingObjectParser<Parameters, Void> objectParser = new ConstructingObjectParser<>("outlier_detection_parameters",
            ignoreUnknownFields,
            args -> new Parameters(
                (int) args[0],
                (String) args[1],
                (boolean) args[2],
                (double) args[3],
                (double) args[4],
                (boolean) args[5]
            ));
        objectParser.declareInt(constructorArg(), N_NEIGHBORS);
        objectParser.declareString(constructorArg(), METHOD);
        objectParser.declareBoolean(constructorArg(), COMPUTE_FEATURE_INFLUENCE);
        objectParser.declareDouble(constructorArg(), FEATURE_INFLUENCE_THRESHOLD);
        objectParser.declareDouble(constructorArg(), OUTLIER_FRACTION);
        objectParser.declareBoolean(constructorArg(), STANDARDIZATION_ENABLED);
        return objectParser;
    }

    private final int nNeighbors;
    private final String method;
    private final boolean computeFeatureInfluence;
    private final double featureInfluenceThreshold;
    private final double outlierFraction;
    private final boolean standardizationEnabled;

    public Parameters(int nNeighbors, String method, boolean computeFeatureInfluence, double featureInfluenceThreshold,
                      double outlierFraction, boolean standardizationEnabled) {
        this.nNeighbors = nNeighbors;
        this.method = method;
        this.computeFeatureInfluence = computeFeatureInfluence;
        this.featureInfluenceThreshold = featureInfluenceThreshold;
        this.outlierFraction = outlierFraction;
        this.standardizationEnabled = standardizationEnabled;
    }

    public Parameters(StreamInput in) throws IOException {
        this.nNeighbors = in.readVInt();
        this.method = in.readString();
        this.computeFeatureInfluence = in.readBoolean();
        this.featureInfluenceThreshold = in.readDouble();
        this.outlierFraction = in.readDouble();
        this.standardizationEnabled = in.readBoolean();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Wire order must match the StreamInput constructor above.
        out.writeVInt(nNeighbors);
        out.writeString(method);
        out.writeBoolean(computeFeatureInfluence);
        out.writeDouble(featureInfluenceThreshold);
        out.writeDouble(outlierFraction);
        out.writeBoolean(standardizationEnabled);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors);
        builder.field(METHOD.getPreferredName(), method);
        builder.field(COMPUTE_FEATURE_INFLUENCE.getPreferredName(), computeFeatureInfluence);
        builder.field(FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), featureInfluenceThreshold);
        builder.field(OUTLIER_FRACTION.getPreferredName(), outlierFraction);
        builder.field(STANDARDIZATION_ENABLED.getPreferredName(), standardizationEnabled);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Parameters other = (Parameters) o;
        return nNeighbors == other.nNeighbors
            && Objects.equals(method, other.method)
            && computeFeatureInfluence == other.computeFeatureInfluence
            && featureInfluenceThreshold == other.featureInfluenceThreshold
            && outlierFraction == other.outlierFraction
            && standardizationEnabled == other.standardizationEnabled;
    }

    @Override
    public int hashCode() {
        return Objects.hash(nNeighbors, method, computeFeatureInfluence, featureInfluenceThreshold, outlierFraction,
            standardizationEnabled);
    }
}

View File

@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
/**
 * Timing stats for an outlier detection analysis: the total time the
 * analysis has taken so far.
 */
public class TimingStats implements Writeable, ToXContentObject {

    public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");

    public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<TimingStats, Void> createParser(boolean ignoreUnknownFields) {
        // Serialized as a millisecond long and rehydrated into a TimeValue.
        ConstructingObjectParser<TimingStats, Void> objectParser = new ConstructingObjectParser<>("outlier_detection_timing_stats",
            ignoreUnknownFields,
            args -> new TimingStats(TimeValue.timeValueMillis((long) args[0])));
        objectParser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME);
        return objectParser;
    }

    // Total time taken by the analysis so far.
    private final TimeValue elapsedTime;

    public TimingStats(TimeValue elapsedTime) {
        this.elapsedTime = Objects.requireNonNull(elapsedTime);
    }

    public TimingStats(StreamInput in) throws IOException {
        this.elapsedTime = in.readTimeValue();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeTimeValue(elapsedTime);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // humanReadableField emits the raw value plus a "<name>_string" rendering when requested.
        builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        TimingStats other = (TimingStats) o;
        return Objects.equals(elapsedTime, other.elapsedTime);
    }

    @Override
    public int hashCode() {
        return Objects.hash(elapsedTime);
    }
}

View File

@ -0,0 +1,226 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
 * The hyperparameters a regression analysis was trained with.
 * Immutable value object; supports both wire serialization (Writeable)
 * and XContent rendering/parsing.
 */
public class Hyperparameters implements ToXContentObject, Writeable {
public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
public static final ParseField ETA = new ParseField("eta");
public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
"max_optimization_rounds_per_hyperparameter");
public static final ParseField MAX_TREES = new ParseField("max_trees");
public static final ParseField NUM_FOLDS = new ParseField("num_folds");
public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
= new ParseField("regularization_leaf_weight_penalty_multiplier");
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
new ParseField("regularization_tree_size_penalty_multiplier");
/**
 * Parses a Hyperparameters object from XContent.
 *
 * @param parser the XContent parser positioned at the object
 * @param ignoreUnknownFields whether unknown fields are tolerated (lenient parsing)
 */
public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
return createParser(ignoreUnknownFields).apply(parser, null);
}
// All 14 fields are required constructor args; the declaration order below
// fixes the index of each value in the untyped args array a[].
private static ConstructingObjectParser<Hyperparameters, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<Hyperparameters, Void> parser = new ConstructingObjectParser<>("regression_hyperparameters",
ignoreUnknownFields,
a -> new Hyperparameters(
(double) a[0],
(double) a[1],
(double) a[2],
(double) a[3],
(int) a[4],
(int) a[5],
(int) a[6],
(int) a[7],
(int) a[8],
(double) a[9],
(double) a[10],
(double) a[11],
(double) a[12],
(double) a[13]
));
parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
parser.declareDouble(constructorArg(), ETA);
parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
parser.declareInt(constructorArg(), MAX_TREES);
parser.declareInt(constructorArg(), NUM_FOLDS);
parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
return parser;
}
private final double downsampleFactor;
private final double eta;
private final double etaGrowthRatePerTree;
private final double featureBagFraction;
private final int maxAttemptsToAddTree;
private final int maxOptimizationRoundsPerHyperparameter;
private final int maxTrees;
private final int numFolds;
private final int numSplitsPerFeature;
private final double regularizationDepthPenaltyMultiplier;
private final double regularizationLeafWeightPenaltyMultiplier;
private final double regularizationSoftTreeDepthLimit;
private final double regularizationSoftTreeDepthTolerance;
private final double regularizationTreeSizePenaltyMultiplier;
public Hyperparameters(double downsampleFactor,
double eta,
double etaGrowthRatePerTree,
double featureBagFraction,
int maxAttemptsToAddTree,
int maxOptimizationRoundsPerHyperparameter,
int maxTrees,
int numFolds,
int numSplitsPerFeature,
double regularizationDepthPenaltyMultiplier,
double regularizationLeafWeightPenaltyMultiplier,
double regularizationSoftTreeDepthLimit,
double regularizationSoftTreeDepthTolerance,
double regularizationTreeSizePenaltyMultiplier) {
this.downsampleFactor = downsampleFactor;
this.eta = eta;
this.etaGrowthRatePerTree = etaGrowthRatePerTree;
this.featureBagFraction = featureBagFraction;
this.maxAttemptsToAddTree = maxAttemptsToAddTree;
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
this.maxTrees = maxTrees;
this.numFolds = numFolds;
this.numSplitsPerFeature = numSplitsPerFeature;
this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
}
// Wire constructor; the read order here must stay in sync with writeTo below.
public Hyperparameters(StreamInput in) throws IOException {
this.downsampleFactor = in.readDouble();
this.eta = in.readDouble();
this.etaGrowthRatePerTree = in.readDouble();
this.featureBagFraction = in.readDouble();
this.maxAttemptsToAddTree = in.readVInt();
this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
this.maxTrees = in.readVInt();
this.numFolds = in.readVInt();
this.numSplitsPerFeature = in.readVInt();
this.regularizationDepthPenaltyMultiplier = in.readDouble();
this.regularizationLeafWeightPenaltyMultiplier = in.readDouble();
this.regularizationSoftTreeDepthLimit = in.readDouble();
this.regularizationSoftTreeDepthTolerance = in.readDouble();
this.regularizationTreeSizePenaltyMultiplier = in.readDouble();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
// Write order must match the StreamInput constructor above.
out.writeDouble(downsampleFactor);
out.writeDouble(eta);
out.writeDouble(etaGrowthRatePerTree);
out.writeDouble(featureBagFraction);
out.writeVInt(maxAttemptsToAddTree);
out.writeVInt(maxOptimizationRoundsPerHyperparameter);
out.writeVInt(maxTrees);
out.writeVInt(numFolds);
out.writeVInt(numSplitsPerFeature);
out.writeDouble(regularizationDepthPenaltyMultiplier);
out.writeDouble(regularizationLeafWeightPenaltyMultiplier);
out.writeDouble(regularizationSoftTreeDepthLimit);
out.writeDouble(regularizationSoftTreeDepthTolerance);
out.writeDouble(regularizationTreeSizePenaltyMultiplier);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
builder.field(ETA.getPreferredName(), eta);
builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
builder.field(MAX_TREES.getPreferredName(), maxTrees);
builder.field(NUM_FOLDS.getPreferredName(), numFolds);
builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Hyperparameters that = (Hyperparameters) o;
// NOTE(review): doubles are compared with ==, so NaN-valued fields would never
// compare equal and -0.0/0.0 compare equal despite differing hashCodes; this
// matches the style of the sibling stats classes — confirm it is intentional.
return downsampleFactor == that.downsampleFactor
&& eta == that.eta
&& etaGrowthRatePerTree == that.etaGrowthRatePerTree
&& featureBagFraction == that.featureBagFraction
&& maxAttemptsToAddTree == that.maxAttemptsToAddTree
&& maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
&& maxTrees == that.maxTrees
&& numFolds == that.numFolds
&& numSplitsPerFeature == that.numSplitsPerFeature
&& regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier
&& regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier
&& regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit
&& regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance
&& regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier;
}
@Override
public int hashCode() {
return Objects.hash(
downsampleFactor,
eta,
etaGrowthRatePerTree,
featureBagFraction,
maxAttemptsToAddTree,
maxOptimizationRoundsPerHyperparameter,
maxTrees,
numFolds,
numSplitsPerFeature,
regularizationDepthPenaltyMultiplier,
regularizationLeafWeightPenaltyMultiplier,
regularizationSoftTreeDepthLimit,
regularizationSoftTreeDepthTolerance,
regularizationTreeSizePenaltyMultiplier
);
}
}

View File

@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
/**
 * Statistics about a running regression analysis. Instances are persisted as
 * documents (one per job and timestamp) and also rendered in API responses.
 */
public class RegressionStats implements AnalysisStats {

    public static final String TYPE_VALUE = "regression_stats";

    public static final ParseField ITERATION = new ParseField("iteration");
    public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
    public static final ParseField TIMING_STATS = new ParseField("timing_stats");
    public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");

    // Strict parser rejects unknown fields; the lenient one ignores them
    // (used when reading documents back, so newer fields do not break parsing).
    public static final ConstructingObjectParser<RegressionStats, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<RegressionStats, Void> LENIENT_PARSER = createParser(true);

    private static ConstructingObjectParser<RegressionStats, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<RegressionStats, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
            a -> new RegressionStats(
                (String) a[0],
                (Instant) a[1],
                (int) a[2],
                (Hyperparameters) a[3],
                (TimingStats) a[4],
                (ValidationLoss) a[5]
            )
        );
        // The "type" field is accepted but deliberately discarded: it only tags
        // the document kind in the index and is not part of the object state.
        parser.declareString((bucket, s) -> {}, Fields.TYPE);
        parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
        parser.declareField(ConstructingObjectParser.constructorArg(),
            p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
            Fields.TIMESTAMP,
            ObjectParser.ValueType.VALUE);
        parser.declareInt(ConstructingObjectParser.constructorArg(), ITERATION);
        // Nested objects inherit the same strict/lenient behavior.
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> Hyperparameters.fromXContent(p, ignoreUnknownFields), HYPERPARAMETERS);
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS);
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> ValidationLoss.fromXContent(p, ignoreUnknownFields), VALIDATION_LOSS);
        return parser;
    }

    private final String jobId;
    // Kept at millisecond granularity; see the constructor comment below.
    private final Instant timestamp;
    private final int iteration;
    private final Hyperparameters hyperparameters;
    private final TimingStats timingStats;
    private final ValidationLoss validationLoss;

    /**
     * @param jobId           id of the data frame analytics job; must not be null
     * @param timestamp       time the stats were captured; truncated to millis; must not be null
     * @param iteration       current iteration of the analysis
     * @param hyperparameters hyperparameters in effect; must not be null
     * @param timingStats     timing measurements; must not be null
     * @param validationLoss  validation loss values; must not be null
     */
    public RegressionStats(String jobId, Instant timestamp, int iteration, Hyperparameters hyperparameters, TimingStats timingStats,
                           ValidationLoss validationLoss) {
        this.jobId = Objects.requireNonNull(jobId);
        // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
        // internal representation matches toXContent
        this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
        this.iteration = iteration;
        this.hyperparameters = Objects.requireNonNull(hyperparameters);
        this.timingStats = Objects.requireNonNull(timingStats);
        this.validationLoss = Objects.requireNonNull(validationLoss);
    }

    /** Wire deserialization; read order must mirror {@link #writeTo}. */
    public RegressionStats(StreamInput in) throws IOException {
        this.jobId = in.readString();
        this.timestamp = in.readInstant();
        this.iteration = in.readVInt();
        this.hyperparameters = new Hyperparameters(in);
        this.timingStats = new TimingStats(in);
        this.validationLoss = new ValidationLoss(in);
    }

    @Override
    public String getWriteableName() {
        return TYPE_VALUE;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(jobId);
        out.writeInstant(timestamp);
        out.writeVInt(iteration);
        hyperparameters.writeTo(out);
        timingStats.writeTo(out);
        validationLoss.writeTo(out);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // type and job_id are only emitted when persisting to the internal index,
        // not when rendering an API response.
        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
            builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
            builder.field(Fields.JOB_ID.getPreferredName(), jobId);
        }
        builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
        builder.field(ITERATION.getPreferredName(), iteration);
        builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
        builder.field(TIMING_STATS.getPreferredName(), timingStats);
        builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        RegressionStats that = (RegressionStats) o;
        return Objects.equals(jobId, that.jobId)
            && Objects.equals(timestamp, that.timestamp)
            && iteration == that.iteration
            && Objects.equals(hyperparameters, that.hyperparameters)
            && Objects.equals(timingStats, that.timingStats)
            && Objects.equals(validationLoss, that.validationLoss);
    }

    @Override
    public int hashCode() {
        return Objects.hash(jobId, timestamp, iteration, hyperparameters, timingStats, validationLoss);
    }

    /** Document id for these stats: the per-job prefix plus the timestamp in millis. */
    public String documentId(String jobId) {
        return documentIdPrefix(jobId) + timestamp.toEpochMilli();
    }

    /** Prefix shared by all regression stats document ids of the given job. */
    public static String documentIdPrefix(String jobId) {
        return TYPE_VALUE + "_" + jobId + "_";
    }
}

View File

@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
/**
 * Timing measurements of a regression analysis: the total elapsed time and the
 * duration of the latest iteration.
 */
public class TimingStats implements Writeable, ToXContentObject {

    public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
    public static final ParseField ITERATION_TIME = new ParseField("iteration_time");

    private final TimeValue elapsedTime;
    private final TimeValue iterationTime;

    public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<TimingStats, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<TimingStats, Void> objectParser = new ConstructingObjectParser<>(
            "regression_timing_stats",
            ignoreUnknownFields,
            args -> new TimingStats(TimeValue.timeValueMillis((long) args[0]), TimeValue.timeValueMillis((long) args[1])));
        objectParser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME);
        objectParser.declareLong(ConstructingObjectParser.constructorArg(), ITERATION_TIME);
        return objectParser;
    }

    public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
        this.elapsedTime = Objects.requireNonNull(elapsedTime);
        this.iterationTime = Objects.requireNonNull(iterationTime);
    }

    /** Wire deserialization; read order mirrors {@link #writeTo}. */
    public TimingStats(StreamInput in) throws IOException {
        this.elapsedTime = in.readTimeValue();
        this.iterationTime = in.readTimeValue();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeTimeValue(elapsedTime);
        out.writeTimeValue(iterationTime);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // humanReadableField also emits a "*_string" variant when human readability is requested.
        builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
        builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        TimingStats other = (TimingStats) o;
        return Objects.equals(elapsedTime, other.elapsedTime) && Objects.equals(iterationTime, other.iterationTime);
    }

    @Override
    public int hashCode() {
        return Objects.hash(elapsedTime, iterationTime);
    }
}

View File

@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValues;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * Validation loss results of a regression analysis: the loss type and the loss
 * values collected per cross-validation fold.
 */
public class ValidationLoss implements ToXContentObject, Writeable {

    public static final ParseField LOSS_TYPE = new ParseField("loss_type");
    public static final ParseField FOLD_VALUES = new ParseField("fold_values");

    private final String lossType;
    private final List<FoldValues> foldValues;

    public static ValidationLoss fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    private static ConstructingObjectParser<ValidationLoss, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<ValidationLoss, Void> objectParser = new ConstructingObjectParser<>("regression_validation_loss",
            ignoreUnknownFields,
            args -> new ValidationLoss((String) args[0], (List<FoldValues>) args[1]));
        objectParser.declareString(ConstructingObjectParser.constructorArg(), LOSS_TYPE);
        // Fold values inherit the same strict/lenient parsing behavior.
        objectParser.declareObjectArray(ConstructingObjectParser.constructorArg(),
            (p, c) -> FoldValues.fromXContent(p, ignoreUnknownFields), FOLD_VALUES);
        return objectParser;
    }

    public ValidationLoss(String lossType, List<FoldValues> values) {
        this.lossType = Objects.requireNonNull(lossType);
        this.foldValues = Objects.requireNonNull(values);
    }

    /** Wire deserialization; read order mirrors {@link #writeTo}. */
    public ValidationLoss(StreamInput in) throws IOException {
        this.lossType = in.readString();
        this.foldValues = in.readList(FoldValues::new);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(lossType);
        out.writeList(foldValues);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(LOSS_TYPE.getPreferredName(), lossType);
        builder.field(FOLD_VALUES.getPreferredName(), foldValues);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        ValidationLoss other = (ValidationLoss) o;
        return Objects.equals(lossType, other.lossType) && Objects.equals(foldValues, other.foldValues);
    }

    @Override
    public int hashCode() {
        return Objects.hash(lossType, foldValues);
    }
}

View File

@ -3,18 +3,120 @@
"_meta": {
"version" : "${xpack.ml.version}"
},
"dynamic": false,
"properties" : {
"type" : {
"type" : "keyword"
"iteration": {
"type": "integer"
},
"hyperparameters": {
"properties": {
"class_assignment_objective": {
"type": "keyword"
},
"downsample_factor": {
"type": "double"
},
"eta": {
"type": "double"
},
"eta_growth_rate_per_tree": {
"type": "double"
},
"feature_bag_fraction": {
"type": "double"
},
"max_attempts_to_add_tree": {
"type": "integer"
},
"max_optimization_rounds_per_hyperparameter": {
"type": "integer"
},
"max_trees": {
"type": "integer"
},
"num_folds": {
"type": "integer"
},
"num_splits_per_feature": {
"type": "integer"
},
"regularization_depth_penalty_multiplier": {
"type": "double"
},
"regularization_leaf_weight_penalty_multiplier": {
"type": "double"
},
"regularization_soft_tree_depth_limit": {
"type": "double"
},
"regularization_soft_tree_depth_tolerance": {
"type": "double"
},
"regularization_tree_size_penalty_multiplier": {
"type": "double"
}
}
},
"job_id" : {
"type" : "keyword"
},
"timestamp" : {
"type" : "date"
"parameters": {
"properties": {
"compute_feature_influence": {
"type": "boolean"
},
"feature_influence_threshold": {
"type": "double"
},
"method": {
"type": "keyword"
},
"n_neighbors": {
"type": "integer"
},
"outlier_fraction": {
"type": "double"
},
"standardization_enabled": {
"type": "boolean"
}
}
},
"peak_usage_bytes" : {
"type" : "long"
},
"timestamp" : {
"type" : "date"
},
"timing_stats": {
"properties": {
"elapsed_time": {
"type": "long"
},
"iteration_time": {
"type": "long"
}
}
},
"type" : {
"type" : "keyword"
},
"validation_loss": {
"properties": {
"fold_values": {
"properties": {
"fold": {
"type": "integer"
},
"values": {
"type": "double"
}
}
},
"loss_type": {
"type": "keyword"
}
}
}
}
}

View File

@ -5,14 +5,20 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.util.ArrayList;
@ -21,6 +27,13 @@ import java.util.stream.IntStream;
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables());
return new NamedWriteableRegistry(namedWriteables);
}
public static Response randomResponse(int listSize) {
List<Response.Stats> analytics = new ArrayList<>(listSize);
for (int j = 0; j < listSize; j++) {
@ -30,8 +43,15 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
IntStream.of(progressSize).forEach(progressIndex -> progress.add(
new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))));
MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom();
AnalysisStats analysisStats = randomBoolean() ? null :
randomFrom(
ClassificationStatsTests.createRandom(),
OutlierDetectionStatsTests.createRandom(),
RegressionStatsTests.createRandom()
);
Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(),
randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, null, randomAlphaOfLength(20));
randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null,
randomAlphaOfLength(20));
analytics.add(stats);
}
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
/**
 * Serialization/parsing round-trip tests for {@link ClassificationStats}.
 */
public class ClassificationStatsTests extends AbstractBWCSerializationTestCase<ClassificationStats> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static ClassificationStats createRandom() {
        return new ClassificationStats(
            randomAlphaOfLength(10),
            Instant.now(),
            randomIntBetween(1, Integer.MAX_VALUE),
            HyperparametersTests.createRandom(),
            TimingStatsTests.createRandom(),
            ValidationLossTests.createRandom()
        );
    }

    @Override
    protected ClassificationStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected ClassificationStats doParseInstance(XContentParser parser) throws IOException {
        // Only the lenient parser tolerates unknown fields.
        if (lenient) {
            return ClassificationStats.LENIENT_PARSER.apply(parser, null);
        }
        return ClassificationStats.STRICT_PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected ToXContent.Params getToXContentParams() {
        // Serialize using the internal-storage format for the XContent round trip.
        return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
    }

    @Override
    protected Writeable.Reader<ClassificationStats> instanceReader() {
        return ClassificationStats::new;
    }

    @Override
    protected ClassificationStats mutateInstanceForVersion(ClassificationStats instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for classification {@link Hyperparameters}.
 */
public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyperparameters> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static Hyperparameters createRandom() {
        return new Hyperparameters(
            randomAlphaOfLength(10),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomDouble()
        );
    }

    @Override
    protected Hyperparameters createTestInstance() {
        return createRandom();
    }

    @Override
    protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
        return Hyperparameters.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<Hyperparameters> instanceReader() {
        return Hyperparameters::new;
    }

    @Override
    protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for classification {@link TimingStats}.
 */
public class TimingStatsTests extends AbstractBWCSerializationTestCase<TimingStats> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static TimingStats createRandom() {
        return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()), TimeValue.timeValueMillis(randomNonNegativeLong()));
    }

    @Override
    protected TimingStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected TimingStats doParseInstance(XContentParser parser) throws IOException {
        return TimingStats.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<TimingStats> instanceReader() {
        return TimingStats::new;
    }

    @Override
    protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValuesTests;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for classification {@link ValidationLoss}.
 */
public class ValidationLossTests extends AbstractBWCSerializationTestCase<ValidationLoss> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseStrictOrLenient() {
        lenient = randomBoolean();
    }

    // Added for consistency with the sibling stats tests (FoldValuesTests,
    // TimingStatsTests, ...): without this override the lenient flag never
    // caused unknown fields to be inserted, so lenient parsing was untested.
    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
        return ValidationLoss.fromXContent(parser, lenient);
    }

    @Override
    protected Writeable.Reader<ValidationLoss> instanceReader() {
        return ValidationLoss::new;
    }

    @Override
    protected ValidationLoss createTestInstance() {
        return createRandom();
    }

    public static ValidationLoss createRandom() {
        return new ValidationLoss(
            randomAlphaOfLength(10),
            randomList(5, () -> FoldValuesTests.createRandom())
        );
    }

    @Override
    protected ValidationLoss mutateInstanceForVersion(ValidationLoss instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for {@link FoldValues}.
 */
public class FoldValuesTests extends AbstractBWCSerializationTestCase<FoldValues> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static FoldValues createRandom() {
        double[] foldValues = new double[randomIntBetween(0, 10)];
        for (int index = 0; index < foldValues.length; index++) {
            foldValues[index] = randomDouble();
        }
        return new FoldValues(randomIntBetween(0, Integer.MAX_VALUE), foldValues);
    }

    @Override
    protected FoldValues createTestInstance() {
        return createRandom();
    }

    @Override
    protected FoldValues doParseInstance(XContentParser parser) throws IOException {
        return FoldValues.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<FoldValues> instanceReader() {
        return FoldValues::new;
    }

    @Override
    protected FoldValues mutateInstanceForVersion(FoldValues instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
/**
 * Serialization/parsing round-trip tests for {@link OutlierDetectionStats}.
 */
public class OutlierDetectionStatsTests extends AbstractBWCSerializationTestCase<OutlierDetectionStats> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static OutlierDetectionStats createRandom() {
        return new OutlierDetectionStats(
            randomAlphaOfLength(10),
            Instant.now(),
            ParametersTests.createRandom(),
            TimingStatsTests.createRandom()
        );
    }

    @Override
    protected OutlierDetectionStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected OutlierDetectionStats doParseInstance(XContentParser parser) throws IOException {
        // Only the lenient parser tolerates unknown fields.
        if (lenient) {
            return OutlierDetectionStats.LENIENT_PARSER.apply(parser, null);
        }
        return OutlierDetectionStats.STRICT_PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected ToXContent.Params getToXContentParams() {
        // Serialize using the internal-storage format for the XContent round trip.
        return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
    }

    @Override
    protected Writeable.Reader<OutlierDetectionStats> instanceReader() {
        return OutlierDetectionStats::new;
    }

    @Override
    protected OutlierDetectionStats mutateInstanceForVersion(OutlierDetectionStats instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for outlier detection {@link Parameters}.
 */
public class ParametersTests extends AbstractBWCSerializationTestCase<Parameters> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static Parameters createRandom() {
        return new Parameters(
            randomIntBetween(1, Integer.MAX_VALUE),
            randomAlphaOfLength(5),
            randomBoolean(),
            randomDouble(),
            randomDouble(),
            randomBoolean()
        );
    }

    @Override
    protected Parameters createTestInstance() {
        return createRandom();
    }

    @Override
    protected Parameters doParseInstance(XContentParser parser) throws IOException {
        return Parameters.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<Parameters> instanceReader() {
        return Parameters::new;
    }

    @Override
    protected Parameters mutateInstanceForVersion(Parameters instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for outlier detection {@link TimingStats}.
 */
public class TimingStatsTests extends AbstractBWCSerializationTestCase<TimingStats> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static TimingStats createRandom() {
        return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()));
    }

    @Override
    protected TimingStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected TimingStats doParseInstance(XContentParser parser) throws IOException {
        return TimingStats.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<TimingStats> instanceReader() {
        return TimingStats::new;
    }

    @Override
    protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for regression {@link Hyperparameters}.
 */
public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyperparameters> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static Hyperparameters createRandom() {
        return new Hyperparameters(
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomIntBetween(0, Integer.MAX_VALUE),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomDouble(),
            randomDouble()
        );
    }

    @Override
    protected Hyperparameters createTestInstance() {
        return createRandom();
    }

    @Override
    protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
        return Hyperparameters.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<Hyperparameters> instanceReader() {
        return Hyperparameters::new;
    }

    @Override
    protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
/**
 * Serialization/parsing round-trip tests for {@link RegressionStats}.
 */
public class RegressionStatsTests extends AbstractBWCSerializationTestCase<RegressionStats> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static RegressionStats createRandom() {
        return new RegressionStats(
            randomAlphaOfLength(10),
            Instant.now(),
            randomIntBetween(1, Integer.MAX_VALUE),
            HyperparametersTests.createRandom(),
            TimingStatsTests.createRandom(),
            ValidationLossTests.createRandom()
        );
    }

    @Override
    protected RegressionStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected RegressionStats doParseInstance(XContentParser parser) throws IOException {
        // Only the lenient parser tolerates unknown fields.
        if (lenient) {
            return RegressionStats.LENIENT_PARSER.apply(parser, null);
        }
        return RegressionStats.STRICT_PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected ToXContent.Params getToXContentParams() {
        // Serialize using the internal-storage format for the XContent round trip.
        return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
    }

    @Override
    protected Writeable.Reader<RegressionStats> instanceReader() {
        return RegressionStats::new;
    }

    @Override
    protected RegressionStats mutateInstanceForVersion(RegressionStats instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
/**
 * Serialization/parsing round-trip tests for regression {@link TimingStats}.
 */
public class TimingStatsTests extends AbstractBWCSerializationTestCase<TimingStats> {

    // Chosen once per test run: exercise either the strict or the lenient parser.
    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    public static TimingStats createRandom() {
        return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()), TimeValue.timeValueMillis(randomNonNegativeLong()));
    }

    @Override
    protected TimingStats createTestInstance() {
        return createRandom();
    }

    @Override
    protected TimingStats doParseInstance(XContentParser parser) throws IOException {
        return TimingStats.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected Writeable.Reader<TimingStats> instanceReader() {
        return TimingStats::new;
    }

    @Override
    protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) {
        // No version-specific adjustments are needed.
        return instance;
    }
}

View File

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValuesTests;
import org.junit.Before;
import java.io.IOException;
public class ValidationLossTests extends AbstractBWCSerializationTestCase<ValidationLoss> {

    // Whether parsing tolerates unknown fields for this test run; randomized per test.
    private boolean lenient;

    @Before
    public void setUpLenientFlag() {
        lenient = randomBoolean();
    }

    @Override
    protected ValidationLoss mutateInstanceForVersion(ValidationLoss instance, Version version) {
        // No fields vary by wire version, so the instance is returned unmodified.
        return instance;
    }

    @Override
    protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
        return ValidationLoss.fromXContent(parser, lenient);
    }

    @Override
    protected Writeable.Reader<ValidationLoss> instanceReader() {
        return ValidationLoss::new;
    }

    @Override
    protected ValidationLoss createTestInstance() {
        return createRandom();
    }

    // Builds a ValidationLoss with a random loss type and up to 5 random fold values.
    public static ValidationLoss createRandom() {
        return new ValidationLoss(
            randomAlphaOfLength(10),
            randomList(5, FoldValuesTests::createRandom)
        );
    }
}

View File

@ -41,7 +41,12 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.R
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
@ -103,7 +108,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
Stats stats = buildStats(
task.getParams().getId(),
statsHolder.getProgressTracker().report(),
statsHolder.getMemoryUsage()
statsHolder.getMemoryUsage(),
statsHolder.getAnalysisStats()
);
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
@ -192,7 +198,10 @@ public class TransportGetDataFrameAnalyticsStatsAction
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
multiSearchRequest.add(buildStoredProgressSearch(configId));
multiSearchRequest.add(buildMemoryUsageSearch(configId));
multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, RegressionStats.TYPE_VALUE));
executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap(
multiSearchResponse -> {
@ -213,7 +222,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
}
listener.onResponse(buildStats(configId,
retrievedStatsHolder.progress.get(),
retrievedStatsHolder.memoryUsage
retrievedStatsHolder.memoryUsage,
retrievedStatsHolder.analysisStats
));
},
e -> listener.onFailure(ExceptionsHelper.serverError("Error searching for stats", e))
@ -228,18 +238,17 @@ public class TransportGetDataFrameAnalyticsStatsAction
return searchRequest;
}
private static SearchRequest buildMemoryUsageSearch(String configId) {
private static SearchRequest buildStatsDocSearch(String configId, String statsType) {
SearchRequest searchRequest = new SearchRequest(MlStatsIndex.indexPattern());
searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
searchRequest.source().size(1);
QueryBuilder query = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(MemoryUsage.JOB_ID.getPreferredName(), configId))
.filter(QueryBuilders.termQuery(MemoryUsage.TYPE.getPreferredName(), MemoryUsage.TYPE_VALUE));
.filter(QueryBuilders.termQuery(Fields.JOB_ID.getPreferredName(), configId))
.filter(QueryBuilders.termQuery(Fields.TYPE.getPreferredName(), statsType));
searchRequest.source().query(query);
searchRequest.source().sort(SortBuilders.fieldSort(MemoryUsage.TIMESTAMP.getPreferredName()).order(SortOrder.DESC)
searchRequest.source().sort(SortBuilders.fieldSort(Fields.TIMESTAMP.getPreferredName()).order(SortOrder.DESC)
// We need this for the search not to fail when there are no mappings yet in the index
.unmappedType("long"));
searchRequest.source().sort(MemoryUsage.TIMESTAMP.getPreferredName(), SortOrder.DESC);
return searchRequest;
}
@ -249,6 +258,12 @@ public class TransportGetDataFrameAnalyticsStatsAction
retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER);
} else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) {
retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER);
} else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) {
retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, OutlierDetectionStats.LENIENT_PARSER);
} else if (hitId.startsWith(ClassificationStats.documentIdPrefix(configId))) {
retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, ClassificationStats.LENIENT_PARSER);
} else if (hitId.startsWith(RegressionStats.documentIdPrefix(configId))) {
retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, RegressionStats.LENIENT_PARSER);
} else {
throw ExceptionsHelper.serverError("unexpected doc id [" + hitId + "]");
}
@ -256,7 +271,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId,
List<PhaseProgress> progress,
MemoryUsage memoryUsage) {
MemoryUsage memoryUsage,
AnalysisStats analysisStats) {
ClusterState clusterState = clusterService.state();
PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
PersistentTasksCustomMetaData.PersistentTask<?> analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks);
@ -278,6 +294,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
failureReason,
progress,
memoryUsage,
analysisStats,
node,
assignmentExplanation
);
@ -287,5 +304,6 @@ public class TransportGetDataFrameAnalyticsStatsAction
private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report());
private volatile MemoryUsage memoryUsage;
private volatile AnalysisStats analysisStats;
}
}

View File

@ -23,6 +23,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@ -175,6 +178,21 @@ public class AnalyticsResultProcessor {
statsHolder.setMemoryUsage(memoryUsage);
indexStatsResult(memoryUsage, memoryUsage::documentId);
}
OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats();
if (outlierDetectionStats != null) {
statsHolder.setAnalysisStats(outlierDetectionStats);
indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId);
}
ClassificationStats classificationStats = result.getClassificationStats();
if (classificationStats != null) {
statsHolder.setAnalysisStats(classificationStats);
indexStatsResult(classificationStats, classificationStats::documentId);
}
RegressionStats regressionStats = result.getRegressionStats();
if (regressionStats != null) {
statsHolder.setAnalysisStats(regressionStats);
indexStatsResult(regressionStats, regressionStats::documentId);
}
}
private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {

View File

@ -12,6 +12,9 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import java.io.IOException;
@ -28,9 +31,20 @@ public class AnalyticsResult implements ToXContentObject {
private static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
private static final ParseField REGRESSION_STATS = new ParseField("regression_stats");
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition.Builder) a[2], (MemoryUsage) a[3]));
a -> new AnalyticsResult(
(RowResults) a[0],
(Integer) a[1],
(TrainedModelDefinition.Builder) a[2],
(MemoryUsage) a[3],
(OutlierDetectionStats) a[4],
(ClassificationStats) a[5],
(RegressionStats) a[6]
));
static {
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
@ -38,6 +52,9 @@ public class AnalyticsResult implements ToXContentObject {
// TODO change back to STRICT_PARSER once native side is aligned
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE);
PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
}
private final RowResults rowResults;
@ -45,16 +62,25 @@ public class AnalyticsResult implements ToXContentObject {
private final TrainedModelDefinition.Builder inferenceModelBuilder;
private final TrainedModelDefinition inferenceModel;
private final MemoryUsage memoryUsage;
private final OutlierDetectionStats outlierDetectionStats;
private final ClassificationStats classificationStats;
private final RegressionStats regressionStats;
public AnalyticsResult(@Nullable RowResults rowResults,
@Nullable Integer progressPercent,
@Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
@Nullable MemoryUsage memoryUsage) {
@Nullable MemoryUsage memoryUsage,
@Nullable OutlierDetectionStats outlierDetectionStats,
@Nullable ClassificationStats classificationStats,
@Nullable RegressionStats regressionStats) {
this.rowResults = rowResults;
this.progressPercent = progressPercent;
this.inferenceModelBuilder = inferenceModelBuilder;
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
this.memoryUsage = memoryUsage;
this.outlierDetectionStats = outlierDetectionStats;
this.classificationStats = classificationStats;
this.regressionStats = regressionStats;
}
public RowResults getRowResults() {
@ -73,6 +99,18 @@ public class AnalyticsResult implements ToXContentObject {
return memoryUsage;
}
public OutlierDetectionStats getOutlierDetectionStats() {
return outlierDetectionStats;
}
public ClassificationStats getClassificationStats() {
return classificationStats;
}
public RegressionStats getRegressionStats() {
return regressionStats;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -90,6 +128,15 @@ public class AnalyticsResult implements ToXContentObject {
if (memoryUsage != null) {
builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params);
}
if (outlierDetectionStats != null) {
builder.field(OUTLIER_DETECTION_STATS.getPreferredName(), outlierDetectionStats, params);
}
if (classificationStats != null) {
builder.field(CLASSIFICATION_STATS.getPreferredName(), classificationStats, params);
}
if (regressionStats != null) {
builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params);
}
builder.endObject();
return builder;
}
@ -107,11 +154,15 @@ public class AnalyticsResult implements ToXContentObject {
return Objects.equals(rowResults, that.rowResults)
&& Objects.equals(progressPercent, that.progressPercent)
&& Objects.equals(inferenceModel, that.inferenceModel)
&& Objects.equals(memoryUsage, that.memoryUsage);
&& Objects.equals(memoryUsage, that.memoryUsage)
&& Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
&& Objects.equals(classificationStats, that.classificationStats)
&& Objects.equals(regressionStats, that.regressionStats);
}
@Override
public int hashCode() {
return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage);
return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
regressionStats);
}
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.dataframe.stats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import java.util.concurrent.atomic.AtomicReference;
@ -17,10 +18,12 @@ public class StatsHolder {
private final ProgressTracker progressTracker;
private final AtomicReference<MemoryUsage> memoryUsageHolder;
private final AtomicReference<AnalysisStats> analysisStatsHolder;
public StatsHolder() {
progressTracker = new ProgressTracker();
memoryUsageHolder = new AtomicReference<>();
analysisStatsHolder = new AtomicReference<>();
}
public ProgressTracker getProgressTracker() {
@ -34,4 +37,12 @@ public class StatsHolder {
public MemoryUsage getMemoryUsage() {
return memoryUsageHolder.get();
}
public void setAnalysisStats(AnalysisStats analysisStats) {
analysisStatsHolder.set(analysisStats);
}
public AnalysisStats getAnalysisStats() {
return analysisStatsHolder.get();
}
}

View File

@ -56,7 +56,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
private static final String CONFIG_ID = "config-id";
private static final int NUM_ROWS = 100;
private static final int NUM_COLS = 4;
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null);
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null);
private Client client;
private DataFrameAnalyticsAuditor auditor;

View File

@ -100,7 +100,9 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
public void testProcess_GivenEmptyResults() {
givenDataFrameRows(2);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50, null, null), new AnalyticsResult(null, 100, null, null)));
givenProcessResults(Arrays.asList(
new AnalyticsResult(null, 50, null, null, null, null, null),
new AnalyticsResult(null, 100, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process);
@ -115,8 +117,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null),
new AnalyticsResult(rowResults2, 100, null, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process);
@ -133,8 +135,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null),
new AnalyticsResult(rowResults2, 100, null, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
@ -167,7 +169,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
resultProcessor.process(process);
@ -212,7 +214,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process);

View File

@ -11,7 +11,14 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
@ -36,6 +43,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
RowResults rowResults = null;
Integer progressPercent = null;
TrainedModelDefinition.Builder inferenceModel = null;
MemoryUsage memoryUsage = null;
OutlierDetectionStats outlierDetectionStats = null;
ClassificationStats classificationStats = null;
RegressionStats regressionStats = null;
if (randomBoolean()) {
rowResults = RowResultsTests.createRandom();
}
@ -45,7 +56,20 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
if (randomBoolean()) {
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
}
return new AnalyticsResult(rowResults, progressPercent, inferenceModel, MemoryUsageTests.createRandom());
if (randomBoolean()) {
memoryUsage = MemoryUsageTests.createRandom();
}
if (randomBoolean()) {
outlierDetectionStats = OutlierDetectionStatsTests.createRandom();
}
if (randomBoolean()) {
classificationStats = ClassificationStatsTests.createRandom();
}
if (randomBoolean()) {
regressionStats = RegressionStatsTests.createRandom();
}
return new AnalyticsResult(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
regressionStats);
}
@Override