diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index 53e3adf2b84..70169ce09b2 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -20,12 +20,15 @@ package org.elasticsearch.client.ml.dataframe; import org.elasticsearch.client.ml.NodeAttributes; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.inject.internal.ToStringBuilder; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import java.io.IOException; import java.util.List; @@ -45,6 +48,7 @@ public class DataFrameAnalyticsStats { static final ParseField FAILURE_REASON = new ParseField("failure_reason"); static final ParseField PROGRESS = new ParseField("progress"); static final ParseField MEMORY_USAGE = new ParseField("memory_usage"); + static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats"); static final ParseField NODE = new ParseField("node"); static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation"); @@ -57,8 +61,9 @@ public class DataFrameAnalyticsStats { (String) args[2], (List) args[3], (MemoryUsage) args[4], - (NodeAttributes) args[5], - (String) args[6])); + (AnalysisStats) args[5], + (NodeAttributes) args[6], + (String) args[7])); static { PARSER.declareString(constructorArg(), ID); @@ -71,26 +76,38 @@ public class DataFrameAnalyticsStats { PARSER.declareString(optionalConstructorArg(), FAILURE_REASON); PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS); PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE); PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION); } + private static AnalysisStats parseAnalysisStats(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + AnalysisStats analysisStats = parser.namedObject(AnalysisStats.class, parser.currentName(), true); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return analysisStats; + } + private final String id; private final DataFrameAnalyticsState state; private final String failureReason; private final List progress; private final MemoryUsage memoryUsage; + private final AnalysisStats analysisStats; private final NodeAttributes node; private final String assignmentExplanation; public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, @Nullable List progress, @Nullable MemoryUsage memoryUsage, - @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { + @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, + @Nullable String assignmentExplanation) { this.id = id; this.state = state; this.failureReason = failureReason; this.progress = progress; this.memoryUsage = memoryUsage; + this.analysisStats = analysisStats; this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -116,6 +133,11 @@ public class DataFrameAnalyticsStats { return memoryUsage; } + @Nullable + public AnalysisStats getAnalysisStats() { + return analysisStats; + } + public NodeAttributes getNode() { return node; } @@ -135,13 +157,14 @@ public class DataFrameAnalyticsStats { && Objects.equals(failureReason, other.failureReason) && Objects.equals(progress, other.progress) && Objects.equals(memoryUsage, other.memoryUsage) + && Objects.equals(analysisStats, other.analysisStats) && Objects.equals(node, other.node) && Objects.equals(assignmentExplanation, other.assignmentExplanation); } @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progress, memoryUsage, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation); } @Override @@ -152,6 +175,7 @@ public class DataFrameAnalyticsStats { .add("failureReason", failureReason) .add("progress", progress) .add("memoryUsage", memoryUsage) + .add("analysisStats", analysisStats) .add("node", node) .add("assignmentExplanation", assignmentExplanation) .toString(); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/AnalysisStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/AnalysisStats.java new file mode 100644 index 00000000000..c1a823682a7 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/AnalysisStats.java @@ -0,0 +1,29 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Statistics for the data frame analysis + */ +public interface AnalysisStats extends ToXContentObject { + + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/AnalysisStatsNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/AnalysisStatsNamedXContentProvider.java new file mode 100644 index 00000000000..8c9bc615e86 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/AnalysisStatsNamedXContentProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats; + +import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.Arrays; +import java.util.List; + +public class AnalysisStatsNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry( + AnalysisStats.class, + ClassificationStats.NAME, + (p, c) -> ClassificationStats.PARSER.apply(p, null) + ), + new NamedXContentRegistry.Entry( + AnalysisStats.class, + OutlierDetectionStats.NAME, + (p, c) -> OutlierDetectionStats.PARSER.apply(p, null) + ), + new NamedXContentRegistry.Entry( + AnalysisStats.class, + RegressionStats.NAME, + (p, c) -> RegressionStats.PARSER.apply(p, null) + ) + ); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/ClassificationStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/ClassificationStats.java new file mode 100644 index 00000000000..101f74f2fe2 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/ClassificationStats.java @@ -0,0 +1,135 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.client.common.TimeUtil; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.util.Objects; + +public class ClassificationStats implements AnalysisStats { + + public static final ParseField NAME = new ParseField("classification_stats"); + + public static final ParseField TIMESTAMP = new ParseField("timestamp"); + public static final ParseField ITERATION = new ParseField("iteration"); + public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters"); + public static final ParseField TIMING_STATS = new ParseField("timing_stats"); + public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + true, + a -> new ClassificationStats( + (Instant) a[0], + (Integer) a[1], + (Hyperparameters) a[2], + (TimingStats) a[3], + (ValidationLoss) a[4] + ) + ); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), + p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()), + TIMESTAMP, + ObjectParser.ValueType.VALUE); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), ITERATION); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), Hyperparameters.PARSER, HYPERPARAMETERS); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), ValidationLoss.PARSER, VALIDATION_LOSS); + } + + private final Instant timestamp; + private final Integer iteration; + private final Hyperparameters hyperparameters; + private final TimingStats timingStats; + private final ValidationLoss validationLoss; + + public ClassificationStats(Instant timestamp, Integer iteration, Hyperparameters hyperparameters, TimingStats timingStats, + ValidationLoss validationLoss) { + this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli()); + this.iteration = iteration; + this.hyperparameters = Objects.requireNonNull(hyperparameters); + this.timingStats = Objects.requireNonNull(timingStats); + this.validationLoss = Objects.requireNonNull(validationLoss); + } + + public Instant getTimestamp() { + return timestamp; + } + + public Integer getIteration() { + return iteration; + } + + public Hyperparameters getHyperparameters() { + return hyperparameters; + } + + public TimingStats getTimingStats() { + return timingStats; + } + + public ValidationLoss getValidationLoss() { + return validationLoss; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + if (iteration != null) { + builder.field(ITERATION.getPreferredName(), iteration); + } + builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters); + builder.field(TIMING_STATS.getPreferredName(), timingStats); + builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassificationStats that = (ClassificationStats) o; + return Objects.equals(timestamp, that.timestamp) + && Objects.equals(iteration, that.iteration) + && Objects.equals(hyperparameters, that.hyperparameters) + && Objects.equals(timingStats, that.timingStats) + && Objects.equals(validationLoss, that.validationLoss); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, iteration, hyperparameters, timingStats, validationLoss); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/Hyperparameters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/Hyperparameters.java new file mode 100644 index 00000000000..c8d581b1d9c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/Hyperparameters.java @@ -0,0 +1,293 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class Hyperparameters implements ToXContentObject { + + public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective"); + public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree"); + public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField( + "max_optimization_rounds_per_hyperparameter"); + public static final ParseField MAX_TREES = new ParseField("max_trees"); + public static final ParseField NUM_FOLDS = new ParseField("num_folds"); + public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature"); + public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier"); + public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER + = new ParseField("regularization_leaf_weight_penalty_multiplier"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance"); + public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER = + new ParseField("regularization_tree_size_penalty_multiplier"); + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("classification_hyperparameters", + true, + a -> new Hyperparameters( + (String) a[0], + (Double) a[1], + (Double) a[2], + (Double) a[3], + (Double) a[4], + (Integer) a[5], + (Integer) a[6], + (Integer) a[7], + (Integer) a[8], + (Integer) a[9], + (Double) a[10], + (Double) a[11], + (Double) a[12], + (Double) a[13], + (Double) a[14] + )); + + static { + PARSER.declareString(optionalConstructorArg(), CLASS_ASSIGNMENT_OBJECTIVE); + PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR); + PARSER.declareDouble(optionalConstructorArg(), ETA); + PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE); + PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE); + PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER); + PARSER.declareInt(optionalConstructorArg(), MAX_TREES); + PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS); + PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER); + } + + private final String classAssignmentObjective; + private final Double downsampleFactor; + private final Double eta; + private final Double etaGrowthRatePerTree; + private final Double featureBagFraction; + private final Integer maxAttemptsToAddTree; + private final Integer maxOptimizationRoundsPerHyperparameter; + private final Integer maxTrees; + private final Integer numFolds; + private final Integer numSplitsPerFeature; + private final Double regularizationDepthPenaltyMultiplier; + private final Double regularizationLeafWeightPenaltyMultiplier; + private final Double regularizationSoftTreeDepthLimit; + private final Double regularizationSoftTreeDepthTolerance; + private final Double regularizationTreeSizePenaltyMultiplier; + + public Hyperparameters(String classAssignmentObjective, + Double downsampleFactor, + Double eta, + Double etaGrowthRatePerTree, + Double featureBagFraction, + Integer maxAttemptsToAddTree, + Integer maxOptimizationRoundsPerHyperparameter, + Integer maxTrees, + Integer numFolds, + Integer numSplitsPerFeature, + Double regularizationDepthPenaltyMultiplier, + Double regularizationLeafWeightPenaltyMultiplier, + Double regularizationSoftTreeDepthLimit, + Double regularizationSoftTreeDepthTolerance, + Double regularizationTreeSizePenaltyMultiplier) { + this.classAssignmentObjective = classAssignmentObjective; + this.downsampleFactor = downsampleFactor; + this.eta = eta; + this.etaGrowthRatePerTree = etaGrowthRatePerTree; + this.featureBagFraction = featureBagFraction; + this.maxAttemptsToAddTree = maxAttemptsToAddTree; + this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter; + this.maxTrees = maxTrees; + this.numFolds = numFolds; + this.numSplitsPerFeature = numSplitsPerFeature; + this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier; + this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier; + this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit; + this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance; + this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier; + } + + public String getClassAssignmentObjective() { + return classAssignmentObjective; + } + + public Double getDownsampleFactor() { + return downsampleFactor; + } + + public Double getEta() { + return eta; + } + + public Double getEtaGrowthRatePerTree() { + return etaGrowthRatePerTree; + } + + public Double getFeatureBagFraction() { + return featureBagFraction; + } + + public Integer getMaxAttemptsToAddTree() { + return maxAttemptsToAddTree; + } + + public Integer getMaxOptimizationRoundsPerHyperparameter() { + return maxOptimizationRoundsPerHyperparameter; + } + + public Integer getMaxTrees() { + return maxTrees; + } + + public Integer getNumFolds() { + return numFolds; + } + + public Integer getNumSplitsPerFeature() { + return numSplitsPerFeature; + } + + public Double getRegularizationDepthPenaltyMultiplier() { + return regularizationDepthPenaltyMultiplier; + } + + public Double getRegularizationLeafWeightPenaltyMultiplier() { + return regularizationLeafWeightPenaltyMultiplier; + } + + public Double getRegularizationSoftTreeDepthLimit() { + return regularizationSoftTreeDepthLimit; + } + + public Double getRegularizationSoftTreeDepthTolerance() { + return regularizationSoftTreeDepthTolerance; + } + + public Double getRegularizationTreeSizePenaltyMultiplier() { + return regularizationTreeSizePenaltyMultiplier; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (classAssignmentObjective != null) { + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); + } + if (downsampleFactor != null) { + builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (etaGrowthRatePerTree != null) { + builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (maxAttemptsToAddTree != null) { + builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree); + } + if (maxOptimizationRoundsPerHyperparameter != null) { + builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter); + } + if (maxTrees != null) { + builder.field(MAX_TREES.getPreferredName(), maxTrees); + } + if (numFolds != null) { + builder.field(NUM_FOLDS.getPreferredName(), numFolds); + } + if (numSplitsPerFeature != null) { + builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature); + } + if (regularizationDepthPenaltyMultiplier != null) { + builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier); + } + if (regularizationLeafWeightPenaltyMultiplier != null) { + builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier); + } + if (regularizationSoftTreeDepthLimit != null) { + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit); + } + if (regularizationSoftTreeDepthTolerance != null) { + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance); + } + if (regularizationTreeSizePenaltyMultiplier != null) { + builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Hyperparameters that = (Hyperparameters) o; + return Objects.equals(classAssignmentObjective, that.classAssignmentObjective) + && Objects.equals(downsampleFactor, that.downsampleFactor) + && Objects.equals(eta, that.eta) + && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree) + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree) + && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter) + && Objects.equals(maxTrees, that.maxTrees) + && Objects.equals(numFolds, that.numFolds) + && Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature) + && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier) + && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier) + && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit) + && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance) + && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier); + } + + @Override + public int hashCode() { + return Objects.hash( + classAssignmentObjective, + downsampleFactor, + eta, + etaGrowthRatePerTree, + featureBagFraction, + maxAttemptsToAddTree, + maxOptimizationRoundsPerHyperparameter, + maxTrees, + numFolds, + numSplitsPerFeature, + regularizationDepthPenaltyMultiplier, + regularizationLeafWeightPenaltyMultiplier, + regularizationSoftTreeDepthLimit, + regularizationSoftTreeDepthTolerance, + regularizationTreeSizePenaltyMultiplier + ); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/TimingStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/TimingStats.java new file mode 100644 index 00000000000..bad599298a7 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/TimingStats.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class TimingStats implements ToXContentObject { + + public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time"); + public static final ParseField ITERATION_TIME = new ParseField("iteration_time"); + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("classification_timing_stats", true, + a -> new TimingStats( + a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]), + a[1] == null ? null : TimeValue.timeValueMillis((long) a[1]) + )); + + static { + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ITERATION_TIME); + } + + private final TimeValue elapsedTime; + private final TimeValue iterationTime; + + public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) { + this.elapsedTime = elapsedTime; + this.iterationTime = iterationTime; + } + + public TimeValue getElapsedTime() { + return elapsedTime; + } + + public TimeValue getIterationTime() { + return iterationTime; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (elapsedTime != null) { + builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime); + } + if (iterationTime != null) { + builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimingStats that = (TimingStats) o; + return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime); + } + + @Override + public int hashCode() { + return Objects.hash(elapsedTime, iterationTime); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/ValidationLoss.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/ValidationLoss.java new file mode 100644 index 00000000000..a552f5d85e1 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/ValidationLoss.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.client.ml.dataframe.stats.common.FoldValues; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class ValidationLoss implements ToXContentObject { + + public static final ParseField LOSS_TYPE = new ParseField("loss_type"); + public static final ParseField FOLD_VALUES = new ParseField("fold_values"); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("classification_validation_loss", + true, + a -> new ValidationLoss((String) a[0], (List) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LOSS_TYPE); + PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), FoldValues.PARSER, FOLD_VALUES); + } + + private final String lossType; + private final List foldValues; + + public ValidationLoss(String lossType, List values) { + this.lossType = lossType; + this.foldValues = values; + } + + public String getLossType() { + return lossType; + } + + public List getFoldValues() { + return foldValues; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (lossType != null) { + builder.field(LOSS_TYPE.getPreferredName(), lossType); + } + if (foldValues != null) { + builder.field(FOLD_VALUES.getPreferredName(), foldValues); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ValidationLoss that = (ValidationLoss) o; + return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues); + } + + @Override + public int hashCode() { + return Objects.hash(lossType, foldValues); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/FoldValues.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/FoldValues.java new file mode 100644 index 00000000000..30490981d96 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/FoldValues.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class FoldValues implements ToXContentObject { + + public static final ParseField FOLD = new ParseField("fold"); + public static final ParseField VALUES = new ParseField("values"); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("fold_values", true, + a -> new FoldValues((int) a[0], (List) a[1])); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), FOLD); + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), VALUES); + } + + private final int fold; + private final double[] values; + + private FoldValues(int fold, List values) { + this(fold, values.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public FoldValues(int fold, double[] values) { + this.fold = fold; + this.values = values; + } + + public int getFold() { + return fold; + } + + public double[] getValues() { + return values; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FOLD.getPreferredName(), fold); + builder.array(VALUES.getPreferredName(), values); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + FoldValues other = (FoldValues) o; + return fold == other.fold && Arrays.equals(values, other.values); + } + + @Override + public int hashCode() { + return Objects.hash(fold, Arrays.hashCode(values)); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MemoryUsage.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsage.java similarity index 94% rename from client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MemoryUsage.java rename to client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsage.java index 323ebb52a7a..f492d26528e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MemoryUsage.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsage.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml.dataframe; +package org.elasticsearch.client.ml.dataframe.stats.common; import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; @@ -54,6 +54,14 @@ public class MemoryUsage implements ToXContentObject { this.peakUsageBytes = peakUsageBytes; } + public Instant getTimestamp() { + return timestamp; + } + + public long getPeakUsageBytes() { + return peakUsageBytes; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/OutlierDetectionStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/OutlierDetectionStats.java new file mode 100644 index 00000000000..15098e3770f --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/OutlierDetectionStats.java @@ -0,0 +1,105 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.client.common.TimeUtil; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.util.Objects; + +public class OutlierDetectionStats implements AnalysisStats { + + public static final ParseField NAME = new ParseField("outlier_detection_stats"); + + public static final ParseField TIMESTAMP = new ParseField("timestamp"); + public static final ParseField PARAMETERS = new ParseField("parameters"); + public static final ParseField TIMING_STATS = new ParseField("timings_stats"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), true, + a -> new OutlierDetectionStats((Instant) a[0], (Parameters) a[1], (TimingStats) a[2])); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), + p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()), + TIMESTAMP, + ObjectParser.ValueType.VALUE); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), Parameters.PARSER, PARAMETERS); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS); + } + + private final Instant timestamp; + private final Parameters parameters; + private final TimingStats timingStats; + + public OutlierDetectionStats(Instant timestamp, Parameters parameters, TimingStats timingStats) { + this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli()); + this.parameters = Objects.requireNonNull(parameters); + this.timingStats = Objects.requireNonNull(timingStats); + } + + public Instant getTimestamp() { + return timestamp; + } + + public Parameters getParameters() { + return parameters; + } + + public TimingStats getTimingStats() { + return timingStats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + builder.field(PARAMETERS.getPreferredName(), parameters); + builder.field(TIMING_STATS.getPreferredName(), timingStats); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OutlierDetectionStats that = (OutlierDetectionStats) o; + return Objects.equals(timestamp, that.timestamp) + && Objects.equals(parameters, that.parameters) + && Objects.equals(timingStats, that.timingStats); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, parameters, timingStats); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/Parameters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/Parameters.java new file mode 100644 index 00000000000..deafb55081d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/Parameters.java @@ -0,0 +1,146 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class Parameters implements ToXContentObject { + + public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); + public static final ParseField METHOD = new ParseField("method"); + public static final ParseField FEATURE_INFLUENCE_THRESHOLD = new ParseField("feature_influence_threshold"); + public static final ParseField COMPUTE_FEATURE_INFLUENCE = new ParseField("compute_feature_influence"); + public static final ParseField OUTLIER_FRACTION = new ParseField("outlier_fraction"); + public static final ParseField STANDARDIZATION_ENABLED = new ParseField("standardization_enabled"); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("outlier_detection_parameters", + true, + a -> new Parameters( + (Integer) a[0], + (String) a[1], + (Boolean) a[2], + (Double) a[3], + (Double) a[4], + (Boolean) a[5] + )); + + static { + PARSER.declareInt(optionalConstructorArg(), N_NEIGHBORS); + PARSER.declareString(optionalConstructorArg(), METHOD); + PARSER.declareBoolean(optionalConstructorArg(), COMPUTE_FEATURE_INFLUENCE); + PARSER.declareDouble(optionalConstructorArg(), FEATURE_INFLUENCE_THRESHOLD); + PARSER.declareDouble(optionalConstructorArg(), OUTLIER_FRACTION); + PARSER.declareBoolean(optionalConstructorArg(), STANDARDIZATION_ENABLED); + } + + private final Integer nNeighbors; + private final String method; + private final Boolean computeFeatureInfluence; + private final Double featureInfluenceThreshold; + private final Double outlierFraction; + private final Boolean standardizationEnabled; + + public Parameters(Integer nNeighbors, String method, Boolean computeFeatureInfluence, Double featureInfluenceThreshold, + Double outlierFraction, Boolean standardizationEnabled) { + this.nNeighbors = nNeighbors; + this.method = method; + this.computeFeatureInfluence = computeFeatureInfluence; + this.featureInfluenceThreshold = featureInfluenceThreshold; + this.outlierFraction = outlierFraction; + this.standardizationEnabled = standardizationEnabled; + } + + public Integer getnNeighbors() { + return nNeighbors; + } + + public String getMethod() { + return method; + } + + public Boolean getComputeFeatureInfluence() { + return computeFeatureInfluence; + } + + public Double getFeatureInfluenceThreshold() { + return featureInfluenceThreshold; + } + + public Double getOutlierFraction() { + return outlierFraction; + } + + public Boolean getStandardizationEnabled() { + return standardizationEnabled; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (nNeighbors != null) { + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + builder.field(METHOD.getPreferredName(), method); + } + if (computeFeatureInfluence != null) { + builder.field(COMPUTE_FEATURE_INFLUENCE.getPreferredName(), computeFeatureInfluence); + } + if (featureInfluenceThreshold != null) { + builder.field(FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), featureInfluenceThreshold); + } + if (outlierFraction != null) { + builder.field(OUTLIER_FRACTION.getPreferredName(), outlierFraction); + } + if (standardizationEnabled != null) { + builder.field(STANDARDIZATION_ENABLED.getPreferredName(), standardizationEnabled); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Parameters that = (Parameters) o; + return Objects.equals(nNeighbors, that.nNeighbors) + && Objects.equals(method, that.method) + && Objects.equals(computeFeatureInfluence, that.computeFeatureInfluence) + && Objects.equals(featureInfluenceThreshold, that.featureInfluenceThreshold) + && Objects.equals(outlierFraction, that.outlierFraction) + && Objects.equals(standardizationEnabled, that.standardizationEnabled); + } + + @Override + public int hashCode() { + return Objects.hash(nNeighbors, method, computeFeatureInfluence, featureInfluenceThreshold, outlierFraction, + standardizationEnabled); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/TimingStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/TimingStats.java new file mode 100644 index 00000000000..96f93a6651d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/TimingStats.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class TimingStats implements ToXContentObject { + + public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time"); + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("outlier_detection_timing_stats", + true, + a -> new TimingStats(a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]))); + + static { + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME); + } + + private final TimeValue elapsedTime; + + public TimingStats(TimeValue elapsedTime) { + this.elapsedTime = elapsedTime; + } + + public TimeValue getElapsedTime() { + return elapsedTime; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (elapsedTime != null) { + builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimingStats that = (TimingStats) o; + return Objects.equals(elapsedTime, that.elapsedTime); + } + + @Override + public int hashCode() { + return Objects.hash(elapsedTime); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/Hyperparameters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/Hyperparameters.java new file mode 100644 index 00000000000..cb1a0b99ab5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/Hyperparameters.java @@ -0,0 +1,278 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class Hyperparameters implements ToXContentObject { + + public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree"); + public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField( + "max_optimization_rounds_per_hyperparameter"); + public static final ParseField MAX_TREES = new ParseField("max_trees"); + public static final ParseField NUM_FOLDS = new ParseField("num_folds"); + public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature"); + public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier"); + public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER + = new ParseField("regularization_leaf_weight_penalty_multiplier"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance"); + public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER = + new ParseField("regularization_tree_size_penalty_multiplier"); + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("regression_hyperparameters", + true, + a -> new Hyperparameters( + (Double) a[0], + (Double) a[1], + (Double) a[2], + (Double) a[3], + (Integer) a[4], + (Integer) a[5], + (Integer) a[6], + (Integer) a[7], + (Integer) a[8], + (Double) a[9], + (Double) a[10], + (Double) a[11], + (Double) a[12], + (Double) a[13] + )); + + static { + PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR); + PARSER.declareDouble(optionalConstructorArg(), ETA); + PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE); + PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE); + PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER); + PARSER.declareInt(optionalConstructorArg(), MAX_TREES); + PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS); + PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE); + PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER); + } + + private final Double downsampleFactor; + private final Double eta; + private final Double etaGrowthRatePerTree; + private final Double featureBagFraction; + private final Integer maxAttemptsToAddTree; + private final Integer maxOptimizationRoundsPerHyperparameter; + private final Integer maxTrees; + private final Integer numFolds; + private final Integer numSplitsPerFeature; + private final Double regularizationDepthPenaltyMultiplier; + private final Double regularizationLeafWeightPenaltyMultiplier; + private final Double regularizationSoftTreeDepthLimit; + private final Double regularizationSoftTreeDepthTolerance; + private final Double regularizationTreeSizePenaltyMultiplier; + + public Hyperparameters(Double downsampleFactor, + Double eta, + Double etaGrowthRatePerTree, + Double featureBagFraction, + Integer maxAttemptsToAddTree, + Integer maxOptimizationRoundsPerHyperparameter, + Integer maxTrees, + Integer numFolds, + Integer numSplitsPerFeature, + Double regularizationDepthPenaltyMultiplier, + Double regularizationLeafWeightPenaltyMultiplier, + Double regularizationSoftTreeDepthLimit, + Double regularizationSoftTreeDepthTolerance, + Double regularizationTreeSizePenaltyMultiplier) { + this.downsampleFactor = downsampleFactor; + this.eta = eta; + this.etaGrowthRatePerTree = etaGrowthRatePerTree; + this.featureBagFraction = featureBagFraction; + this.maxAttemptsToAddTree = maxAttemptsToAddTree; + this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter; + this.maxTrees = maxTrees; + this.numFolds = numFolds; + this.numSplitsPerFeature = numSplitsPerFeature; + this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier; + this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier; + this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit; + this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance; + this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier; + } + + public Double getDownsampleFactor() { + return downsampleFactor; + } + + public Double getEta() { + return eta; + } + + public Double getEtaGrowthRatePerTree() { + return etaGrowthRatePerTree; + } + + public Double getFeatureBagFraction() { + return featureBagFraction; + } + + public Integer getMaxAttemptsToAddTree() { + return maxAttemptsToAddTree; + } + + public Integer getMaxOptimizationRoundsPerHyperparameter() { + return maxOptimizationRoundsPerHyperparameter; + } + + public Integer getMaxTrees() { + return maxTrees; + } + + public Integer getNumFolds() { + return numFolds; + } + + public Integer getNumSplitsPerFeature() { + return numSplitsPerFeature; + } + + public Double getRegularizationDepthPenaltyMultiplier() { + return regularizationDepthPenaltyMultiplier; + } + + public Double getRegularizationLeafWeightPenaltyMultiplier() { + return regularizationLeafWeightPenaltyMultiplier; + } + + public Double getRegularizationSoftTreeDepthLimit() { + return regularizationSoftTreeDepthLimit; + } + + public Double getRegularizationSoftTreeDepthTolerance() { + return regularizationSoftTreeDepthTolerance; + } + + public Double getRegularizationTreeSizePenaltyMultiplier() { + return regularizationTreeSizePenaltyMultiplier; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (downsampleFactor != null) { + builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (etaGrowthRatePerTree != null) { + builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (maxAttemptsToAddTree != null) { + builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree); + } + if (maxOptimizationRoundsPerHyperparameter != null) { + builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter); + } + if (maxTrees != null) { + builder.field(MAX_TREES.getPreferredName(), maxTrees); + } + if (numFolds != null) { + builder.field(NUM_FOLDS.getPreferredName(), numFolds); + } + if (numSplitsPerFeature != null) { + builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature); + } + if (regularizationDepthPenaltyMultiplier != null) { + builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier); + } + if (regularizationLeafWeightPenaltyMultiplier != null) { + builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier); + } + if (regularizationSoftTreeDepthLimit != null) { + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit); + } + if (regularizationSoftTreeDepthTolerance != null) { + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance); + } + if (regularizationTreeSizePenaltyMultiplier != null) { + builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Hyperparameters that = (Hyperparameters) o; + return Objects.equals(downsampleFactor, that.downsampleFactor) + && Objects.equals(eta, that.eta) + && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree) + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree) + && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter) + && Objects.equals(maxTrees, that.maxTrees) + && Objects.equals(numFolds, that.numFolds) + && Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature) + && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier) + && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier) + && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit) + && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance) + && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier); + } + + @Override + public int hashCode() { + return Objects.hash( + downsampleFactor, + eta, + etaGrowthRatePerTree, + featureBagFraction, + maxAttemptsToAddTree, + maxOptimizationRoundsPerHyperparameter, + maxTrees, + numFolds, + numSplitsPerFeature, + regularizationDepthPenaltyMultiplier, + regularizationLeafWeightPenaltyMultiplier, + regularizationSoftTreeDepthLimit, + regularizationSoftTreeDepthTolerance, + regularizationTreeSizePenaltyMultiplier + ); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/RegressionStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/RegressionStats.java new file mode 100644 index 00000000000..7e890c3618f --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/RegressionStats.java @@ -0,0 +1,135 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.client.common.TimeUtil; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.util.Objects; + +public class RegressionStats implements AnalysisStats { + + public static final ParseField NAME = new ParseField("regression_stats"); + + public static final ParseField TIMESTAMP = new ParseField("timestamp"); + public static final ParseField ITERATION = new ParseField("iteration"); + public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters"); + public static final ParseField TIMING_STATS = new ParseField("timing_stats"); + public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + true, + a -> new RegressionStats( + (Instant) a[0], + (Integer) a[1], + (Hyperparameters) a[2], + (TimingStats) a[3], + (ValidationLoss) a[4] + ) + ); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), + p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()), + TIMESTAMP, + ObjectParser.ValueType.VALUE); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), ITERATION); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), Hyperparameters.PARSER, HYPERPARAMETERS); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), ValidationLoss.PARSER, VALIDATION_LOSS); + } + + private final Instant timestamp; + private final Integer iteration; + private final Hyperparameters hyperparameters; + private final TimingStats timingStats; + private final ValidationLoss validationLoss; + + public RegressionStats(Instant timestamp, Integer iteration, Hyperparameters hyperparameters, TimingStats timingStats, + ValidationLoss validationLoss) { + this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli()); + this.iteration = iteration; + this.hyperparameters = Objects.requireNonNull(hyperparameters); + this.timingStats = Objects.requireNonNull(timingStats); + this.validationLoss = Objects.requireNonNull(validationLoss); + } + + public Instant getTimestamp() { + return timestamp; + } + + public Integer getIteration() { + return iteration; + } + + public Hyperparameters getHyperparameters() { + return hyperparameters; + } + + public TimingStats getTimingStats() { + return timingStats; + } + + public ValidationLoss getValidationLoss() { + return validationLoss; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + if (iteration != null) { + builder.field(ITERATION.getPreferredName(), iteration); + } + builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters); + builder.field(TIMING_STATS.getPreferredName(), timingStats); + builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RegressionStats that = (RegressionStats) o; + return Objects.equals(timestamp, that.timestamp) + && Objects.equals(iteration, that.iteration) + && Objects.equals(hyperparameters, that.hyperparameters) + && Objects.equals(timingStats, that.timingStats) + && Objects.equals(validationLoss, that.validationLoss); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, iteration, hyperparameters, timingStats, validationLoss); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/TimingStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/TimingStats.java new file mode 100644 index 00000000000..1a844a410f4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/TimingStats.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class TimingStats implements ToXContentObject { + + public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time"); + public static final ParseField ITERATION_TIME = new ParseField("iteration_time"); + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("regression_timing_stats", true, + a -> new TimingStats( + a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]), + a[1] == null ? null : TimeValue.timeValueMillis((long) a[1]) + )); + + static { + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ITERATION_TIME); + } + + private final TimeValue elapsedTime; + private final TimeValue iterationTime; + + public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) { + this.elapsedTime = elapsedTime; + this.iterationTime = iterationTime; + } + + public TimeValue getElapsedTime() { + return elapsedTime; + } + + public TimeValue getIterationTime() { + return iterationTime; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (elapsedTime != null) { + builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime); + } + if (iterationTime != null) { + builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimingStats that = (TimingStats) o; + return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime); + } + + @Override + public int hashCode() { + return Objects.hash(elapsedTime, iterationTime); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/ValidationLoss.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/ValidationLoss.java new file mode 100644 index 00000000000..ee2513b0f39 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/ValidationLoss.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.client.ml.dataframe.stats.common.FoldValues; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class ValidationLoss implements ToXContentObject { + + public static final ParseField LOSS_TYPE = new ParseField("loss_type"); + public static final ParseField FOLD_VALUES = new ParseField("fold_values"); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("regression_validation_loss", + true, + a -> new ValidationLoss((String) a[0], (List) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LOSS_TYPE); + PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), FoldValues.PARSER, FOLD_VALUES); + } + + private final String lossType; + private final List foldValues; + + public ValidationLoss(String lossType, List values) { + this.lossType = lossType; + this.foldValues = values; + } + + public String getLossType() { + return lossType; + } + + public List getFoldValues() { + return foldValues; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (lossType != null) { + builder.field(LOSS_TYPE.getPreferredName(), lossType); + } + if (foldValues != null) { + builder.field(FOLD_VALUES.getPreferredName(), foldValues); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ValidationLoss that = (ValidationLoss) o; + return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues); + } + + @Override + public int hashCode() { + return Objects.hash(lossType, foldValues); + } +} diff --git a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider index c3facfa93ff..45d9e3da908 100644 --- a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider +++ b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider @@ -1,5 +1,6 @@ org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider +org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider org.elasticsearch.client.transform.TransformNamedXContentProvider diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 395c62ff837..a9a42c979ab 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -91,6 +91,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.client.ml.inference.TrainedModelConfig; @@ -1067,6 +1068,7 @@ public class MLRequestConvertersTests extends ESTestCase { namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new AnalysisStatsNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index f781856bfe6..e4c4f43356a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,7 +20,6 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; - import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.HttpResponse; @@ -69,6 +68,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; @@ -697,7 +699,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(59, namedXContents.size()); + assertEquals(62, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -707,7 +709,7 @@ public class RestHighLevelClientTests extends ESTestCase { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 12, categories.size()); + assertEquals("Had: " + categories, 13, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -737,6 +739,9 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName())); assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName())); + assertTrue(names.contains(OutlierDetectionStats.NAME.getPreferredName())); + assertTrue(names.contains(RegressionStats.NAME.getPreferredName())); + assertTrue(names.contains(ClassificationStats.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java index 48ebf71e360..25345181982 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -20,6 +20,13 @@ package org.elasticsearch.client.ml.dataframe; import org.elasticsearch.client.ml.NodeAttributesTests; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests; +import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; +import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.test.ESTestCase; @@ -31,23 +38,38 @@ import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; public class DataFrameAnalyticsStatsTests extends ESTestCase { + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new AnalysisStatsNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } + public void testFromXContent() throws IOException { xContentTester(this::createParser, DataFrameAnalyticsStatsTests::randomDataFrameAnalyticsStats, DataFrameAnalyticsStatsTests::toXContent, DataFrameAnalyticsStats::fromXContent) .supportsUnknownFields(true) - .randomFieldsExcludeFilter(field -> field.startsWith("node.attributes")) + .randomFieldsExcludeFilter(field -> field.startsWith("node.attributes") || field.startsWith("analysis_stats")) .test(); } public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { + AnalysisStats analysisStats = randomBoolean() ? null : + randomFrom( + ClassificationStatsTests.createRandom(), + OutlierDetectionStatsTests.createRandom(), + RegressionStatsTests.createRandom() + ); + return new DataFrameAnalyticsStats( randomAlphaOfLengthBetween(1, 10), randomFrom(DataFrameAnalyticsState.values()), randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : createRandomProgress(), randomBoolean() ? null : MemoryUsageTests.createRandom(), + analysisStats, randomBoolean() ? null : NodeAttributesTests.createRandom(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20)); } @@ -74,6 +96,11 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase { if (stats.getMemoryUsage() != null) { builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage()); } + if (stats.getAnalysisStats() != null) { + builder.startObject("analysis_stats"); + builder.field(stats.getAnalysisStats().getName(), stats.getAnalysisStats()); + builder.endObject(); + } if (stats.getNode() != null) { builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/ClassificationStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/ClassificationStatsTests.java new file mode 100644 index 00000000000..d23633c01d2 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/ClassificationStatsTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.time.Instant; + +public class ClassificationStatsTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected ClassificationStats doParseInstance(XContentParser parser) throws IOException { + return ClassificationStats.PARSER.apply(parser, null); + } + + @Override + protected ClassificationStats createTestInstance() { + return createRandom(); + } + + public static ClassificationStats createRandom() { + return new ClassificationStats( + Instant.now(), + randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE), + HyperparametersTests.createRandom(), + TimingStatsTests.createRandom(), + ValidationLossTests.createRandom() + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/HyperparametersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/HyperparametersTests.java new file mode 100644 index 00000000000..aa1ab12c542 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/HyperparametersTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class HyperparametersTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Hyperparameters doParseInstance(XContentParser parser) throws IOException { + return Hyperparameters.PARSER.apply(parser, null); + } + + @Override + protected Hyperparameters createTestInstance() { + return createRandom(); + } + + public static Hyperparameters createRandom() { + return new Hyperparameters( + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble() + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/TimingStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/TimingStatsTests.java new file mode 100644 index 00000000000..5e2c4c842e1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/TimingStatsTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TimingStatsTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected TimingStats doParseInstance(XContentParser parser) throws IOException { + return TimingStats.PARSER.apply(parser, null); + } + + @Override + protected TimingStats createTestInstance() { + return createRandom(); + } + + public static TimingStats createRandom() { + return new TimingStats( + randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()), + randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()) + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/ValidationLossTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/ValidationLossTests.java new file mode 100644 index 00000000000..c841af43d43 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/ValidationLossTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.classification; + +import org.elasticsearch.client.ml.dataframe.stats.common.FoldValuesTests; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ValidationLossTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected ValidationLoss doParseInstance(XContentParser parser) throws IOException { + return ValidationLoss.PARSER.apply(parser, null); + } + + @Override + protected ValidationLoss createTestInstance() { + return createRandom(); + } + + public static ValidationLoss createRandom() { + return new ValidationLoss( + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomList(5, () -> FoldValuesTests.createRandom()) + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/FoldValuesTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/FoldValuesTests.java new file mode 100644 index 00000000000..90d92193276 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/FoldValuesTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class FoldValuesTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected FoldValues doParseInstance(XContentParser parser) throws IOException { + return FoldValues.PARSER.apply(parser, null); + } + + @Override + protected FoldValues createTestInstance() { + return createRandom(); + } + + public static FoldValues createRandom() { + int valuesSize = randomIntBetween(0, 10); + double[] values = new double[valuesSize]; + for (int i = 0; i < valuesSize; i++) { + values[i] = randomDouble(); + } + return new FoldValues(randomIntBetween(0, Integer.MAX_VALUE), values); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/MemoryUsageTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsageTests.java similarity index 96% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/MemoryUsageTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsageTests.java index 8e06db6f2b3..0e272957521 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/MemoryUsageTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsageTests.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml.dataframe; +package org.elasticsearch.client.ml.dataframe.stats.common; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/OutlierDetectionStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/OutlierDetectionStatsTests.java new file mode 100644 index 00000000000..f40de67a62c --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/OutlierDetectionStatsTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.time.Instant; + +public class OutlierDetectionStatsTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected OutlierDetectionStats doParseInstance(XContentParser parser) throws IOException { + return OutlierDetectionStats.PARSER.apply(parser, null); + } + + @Override + protected OutlierDetectionStats createTestInstance() { + return createRandom(); + } + + public static OutlierDetectionStats createRandom() { + return new OutlierDetectionStats( + Instant.now(), + ParametersTests.createRandom(), + TimingStatsTests.createRandom() + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/ParametersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/ParametersTests.java new file mode 100644 index 00000000000..4f566562683 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/ParametersTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ParametersTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Parameters doParseInstance(XContentParser parser) throws IOException { + return Parameters.PARSER.apply(parser, null); + } + + @Override + protected Parameters createTestInstance() { + return createRandom(); + } + + public static Parameters createRandom() { + return new Parameters( + randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE), + randomBoolean() ? null : randomAlphaOfLength(5), + randomBoolean() ? null : randomBoolean(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomBoolean() + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/TimingStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/TimingStatsTests.java new file mode 100644 index 00000000000..5483782e1d1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/outlierdetection/TimingStatsTests.java @@ -0,0 +1,48 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TimingStatsTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + + @Override + protected TimingStats doParseInstance(XContentParser parser) throws IOException { + return TimingStats.PARSER.apply(parser, null); + } + + @Override + protected TimingStats createTestInstance() { + return createRandom(); + } + + public static TimingStats createRandom() { + return new TimingStats(randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong())); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/HyperparametersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/HyperparametersTests.java new file mode 100644 index 00000000000..43d0571bb20 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/HyperparametersTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class HyperparametersTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Hyperparameters doParseInstance(XContentParser parser) throws IOException { + return Hyperparameters.PARSER.apply(parser, null); + } + + + @Override + protected Hyperparameters createTestInstance() { + return createRandom(); + } + + public static Hyperparameters createRandom() { + return new Hyperparameters( + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble(), + randomBoolean() ? null : randomDouble() + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/RegressionStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/RegressionStatsTests.java new file mode 100644 index 00000000000..d4e784bb335 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/RegressionStatsTests.java @@ -0,0 +1,54 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.time.Instant; + +public class RegressionStatsTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected RegressionStats doParseInstance(XContentParser parser) throws IOException { + return RegressionStats.PARSER.apply(parser, null); + } + + + @Override + protected RegressionStats createTestInstance() { + return createRandom(); + } + + public static RegressionStats createRandom() { + return new RegressionStats( + Instant.now(), + randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE), + HyperparametersTests.createRandom(), + TimingStatsTests.createRandom(), + ValidationLossTests.createRandom() + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/TimingStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/TimingStatsTests.java new file mode 100644 index 00000000000..95fe6531f3b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/TimingStatsTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TimingStatsTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected TimingStats doParseInstance(XContentParser parser) throws IOException { + return TimingStats.PARSER.apply(parser, null); + } + + @Override + protected TimingStats createTestInstance() { + return createRandom(); + } + + public static TimingStats createRandom() { + return new TimingStats( + randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()), + randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()) + ); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/ValidationLossTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/ValidationLossTests.java new file mode 100644 index 00000000000..d2a9f960bbb --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/ValidationLossTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.stats.regression; + +import org.elasticsearch.client.ml.dataframe.stats.common.FoldValuesTests; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ValidationLossTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected ValidationLoss doParseInstance(XContentParser parser) throws IOException { + return ValidationLoss.PARSER.apply(parser, null); + } + + @Override + protected ValidationLoss createTestInstance() { + return createRandom(); + } + + public static ValidationLoss createRandom() { + return new ValidationLoss( + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomList(5, () -> FoldValuesTests.createRandom()) + ); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 7c2b869c1a1..ff551398db8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -39,18 +39,15 @@ import org.elasticsearch.transport.Transport; import org.elasticsearch.xpack.core.action.XPackInfoAction; import org.elasticsearch.xpack.core.action.XPackUsageAction; import org.elasticsearch.xpack.core.analytics.AnalyticsFeatureSetUsage; -import org.elasticsearch.xpack.core.search.action.DeleteAsyncSearchAction; -import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction; -import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction; import org.elasticsearch.xpack.core.ccr.AutoFollowMetadata; import org.elasticsearch.xpack.core.ccr.CCRFeatureSet; import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction; import org.elasticsearch.xpack.core.enrich.EnrichFeatureSet; -import org.elasticsearch.xpack.core.eql.EqlFeatureSetUsage; import org.elasticsearch.xpack.core.enrich.action.DeleteEnrichPolicyAction; import org.elasticsearch.xpack.core.enrich.action.ExecuteEnrichPolicyAction; import org.elasticsearch.xpack.core.enrich.action.GetEnrichPolicyAction; import org.elasticsearch.xpack.core.enrich.action.PutEnrichPolicyAction; +import org.elasticsearch.xpack.core.eql.EqlFeatureSetUsage; import org.elasticsearch.xpack.core.flattened.FlattenedFeatureSetUsage; import org.elasticsearch.xpack.core.frozen.FrozenIndicesFeatureSetUsage; import org.elasticsearch.xpack.core.frozen.action.FreezeIndexAction; @@ -152,6 +149,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; @@ -184,6 +185,9 @@ import org.elasticsearch.xpack.core.rollup.action.StartRollupJobAction; import org.elasticsearch.xpack.core.rollup.action.StopRollupJobAction; import org.elasticsearch.xpack.core.rollup.job.RollupJob; import org.elasticsearch.xpack.core.rollup.job.RollupJobStatus; +import org.elasticsearch.xpack.core.search.action.DeleteAsyncSearchAction; +import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction; +import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction; import org.elasticsearch.xpack.core.security.SecurityFeatureSetUsage; import org.elasticsearch.xpack.core.security.SecurityField; import org.elasticsearch.xpack.core.security.SecuritySettings; @@ -505,6 +509,9 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new), + new NamedWriteableRegistry.Entry(AnalysisStats.class, OutlierDetectionStats.TYPE_VALUE, OutlierDetectionStats::new), + new NamedWriteableRegistry.Entry(AnalysisStats.class, RegressionStats.TYPE_VALUE, RegressionStats::new), + new NamedWriteableRegistry.Entry(AnalysisStats.class, ClassificationStats.TYPE_VALUE, ClassificationStats::new), // ML - Inference preprocessing new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new), new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index c0590797ead..209058e0046 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -167,18 +168,23 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType progress, - @Nullable MemoryUsage memoryUsage, @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { + @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node, + @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); this.failureReason = failureReason; this.progress = Objects.requireNonNull(progress); this.memoryUsage = memoryUsage; + this.analysisStats = analysisStats; this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -197,6 +203,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType getNamedWriteables() { + return Arrays.asList( + new NamedWriteableRegistry.Entry(AnalysisStats.class, ClassificationStats.TYPE_VALUE, ClassificationStats::new), + new NamedWriteableRegistry.Entry(AnalysisStats.class, OutlierDetectionStats.TYPE_VALUE, OutlierDetectionStats::new), + new NamedWriteableRegistry.Entry(AnalysisStats.class, RegressionStats.TYPE_VALUE, RegressionStats::new) + ); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/Fields.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/Fields.java new file mode 100644 index 00000000000..26b2424a63c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/Fields.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats; + +import org.elasticsearch.common.ParseField; + +/** + * A collection of parse fields commonly used by stats objects + */ +public final class Fields { + + public static final ParseField TYPE = new ParseField("type"); + public static final ParseField JOB_ID = new ParseField("job_id"); + public static final ParseField TIMESTAMP = new ParseField("timestamp"); + + private Fields() {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/MemoryUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/MemoryUsage.java index 5131d88d959..b672c3809d1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/MemoryUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/MemoryUsage.java @@ -26,9 +26,6 @@ public class MemoryUsage implements Writeable, ToXContentObject { public static final String TYPE_VALUE = "analytics_memory_usage"; - public static final ParseField TYPE = new ParseField("type"); - public static final ParseField JOB_ID = new ParseField("job_id"); - public static final ParseField TIMESTAMP = new ParseField("timestamp"); public static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes"); public static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -38,11 +35,11 @@ public class MemoryUsage implements Writeable, ToXContentObject { ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2])); - parser.declareString((bucket, s) -> {}, TYPE); - parser.declareString(ConstructingObjectParser.constructorArg(), JOB_ID); + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); parser.declareField(ConstructingObjectParser.constructorArg(), - p -> TimeUtils.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()), - TIMESTAMP, + p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()), + Fields.TIMESTAMP, ObjectParser.ValueType.VALUE); parser.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES); return parser; @@ -56,7 +53,7 @@ public class MemoryUsage implements Writeable, ToXContentObject { this.jobId = Objects.requireNonNull(jobId); // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure // internal representation matches toXContent - this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, TIMESTAMP).toEpochMilli()); + this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli()); this.peakUsageBytes = peakUsageBytes; } @@ -77,10 +74,10 @@ public class MemoryUsage implements Writeable, ToXContentObject { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { - builder.field(TYPE.getPreferredName(), TYPE_VALUE); - builder.field(JOB_ID.getPreferredName(), jobId); + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); } - builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes); builder.endObject(); return builder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ClassificationStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ClassificationStats.java new file mode 100644 index 00000000000..7cd2186db79 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ClassificationStats.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.time.Instant; +import java.util.Objects; + +public class ClassificationStats implements AnalysisStats { + + public static final String TYPE_VALUE = "classification_stats"; + + public static final ParseField ITERATION = new ParseField("iteration"); + public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters"); + public static final ParseField TIMING_STATS = new ParseField("timing_stats"); + public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss"); + + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, + a -> new ClassificationStats( + (String) a[0], + (Instant) a[1], + (int) a[2], + (Hyperparameters) a[3], + (TimingStats) a[4], + (ValidationLoss) a[5] + ) + ); + + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); + parser.declareField(ConstructingObjectParser.constructorArg(), + p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()), + Fields.TIMESTAMP, + ObjectParser.ValueType.VALUE); + parser.declareInt(ConstructingObjectParser.constructorArg(), ITERATION); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> Hyperparameters.fromXContent(p, ignoreUnknownFields), HYPERPARAMETERS); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> ValidationLoss.fromXContent(p, ignoreUnknownFields), VALIDATION_LOSS); + return parser; + } + + private final String jobId; + private final Instant timestamp; + private final int iteration; + private final Hyperparameters hyperparameters; + private final TimingStats timingStats; + private final ValidationLoss validationLoss; + + public ClassificationStats(String jobId, Instant timestamp, int iteration, Hyperparameters hyperparameters, TimingStats timingStats, + ValidationLoss validationLoss) { + this.jobId = Objects.requireNonNull(jobId); + // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure + // internal representation matches toXContent + this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli()); + this.iteration = iteration; + this.hyperparameters = Objects.requireNonNull(hyperparameters); + this.timingStats = Objects.requireNonNull(timingStats); + this.validationLoss = Objects.requireNonNull(validationLoss); + } + + public ClassificationStats(StreamInput in) throws IOException { + this.jobId = in.readString(); + this.timestamp = in.readInstant(); + this.iteration = in.readVInt(); + this.hyperparameters = new Hyperparameters(in); + this.timingStats = new TimingStats(in); + this.validationLoss = new ValidationLoss(in); + } + + @Override + public String getWriteableName() { + return TYPE_VALUE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(jobId); + out.writeInstant(timestamp); + out.writeVInt(iteration); + hyperparameters.writeTo(out); + timingStats.writeTo(out); + validationLoss.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); + } + builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + builder.field(ITERATION.getPreferredName(), iteration); + builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters); + builder.field(TIMING_STATS.getPreferredName(), timingStats); + builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassificationStats that = (ClassificationStats) o; + return Objects.equals(jobId, that.jobId) + && Objects.equals(timestamp, that.timestamp) + && iteration == that.iteration + && Objects.equals(hyperparameters, that.hyperparameters) + && Objects.equals(timingStats, that.timingStats) + && Objects.equals(validationLoss, that.validationLoss); + } + + @Override + public int hashCode() { + return Objects.hash(jobId, timestamp, iteration, hyperparameters, timingStats, validationLoss); + } + + public String documentId(String jobId) { + return documentIdPrefix(jobId) + timestamp.toEpochMilli(); + } + + public static String documentIdPrefix(String jobId) { + return TYPE_VALUE + "_" + jobId + "_"; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/Hyperparameters.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/Hyperparameters.java new file mode 100644 index 00000000000..c7f2c901428 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/Hyperparameters.java @@ -0,0 +1,237 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class Hyperparameters implements ToXContentObject, Writeable { + + public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective"); + public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree"); + public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField( + "max_optimization_rounds_per_hyperparameter"); + public static final ParseField MAX_TREES = new ParseField("max_trees"); + public static final ParseField NUM_FOLDS = new ParseField("num_folds"); + public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature"); + public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier"); + public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER + = new ParseField("regularization_leaf_weight_penalty_multiplier"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance"); + public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER = + new ParseField("regularization_tree_size_penalty_multiplier"); + + public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("classification_hyperparameters", + ignoreUnknownFields, + a -> new Hyperparameters( + (String) a[0], + (double) a[1], + (double) a[2], + (double) a[3], + (double) a[4], + (int) a[5], + (int) a[6], + (int) a[7], + (int) a[8], + (int) a[9], + (double) a[10], + (double) a[11], + (double) a[12], + (double) a[13], + (double) a[14] + )); + + parser.declareString(constructorArg(), CLASS_ASSIGNMENT_OBJECTIVE); + parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR); + parser.declareDouble(constructorArg(), ETA); + parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE); + parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION); + parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE); + parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER); + parser.declareInt(constructorArg(), MAX_TREES); + parser.declareInt(constructorArg(), NUM_FOLDS); + parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE); + parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER); + parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER); + parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT); + parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE); + parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER); + + return parser; + } + + private final String classAssignmentObjective; + private final double downsampleFactor; + private final double eta; + private final double etaGrowthRatePerTree; + private final double featureBagFraction; + private final int maxAttemptsToAddTree; + private final int maxOptimizationRoundsPerHyperparameter; + private final int maxTrees; + private final int numFolds; + private final int numSplitsPerFeature; + private final double regularizationDepthPenaltyMultiplier; + private final double regularizationLeafWeightPenaltyMultiplier; + private final double regularizationSoftTreeDepthLimit; + private final double regularizationSoftTreeDepthTolerance; + private final double regularizationTreeSizePenaltyMultiplier; + + public Hyperparameters(String classAssignmentObjective, + double downsampleFactor, + double eta, + double etaGrowthRatePerTree, + double featureBagFraction, + int maxAttemptsToAddTree, + int maxOptimizationRoundsPerHyperparameter, + int maxTrees, + int numFolds, + int numSplitsPerFeature, + double regularizationDepthPenaltyMultiplier, + double regularizationLeafWeightPenaltyMultiplier, + double regularizationSoftTreeDepthLimit, + double regularizationSoftTreeDepthTolerance, + double regularizationTreeSizePenaltyMultiplier) { + this.classAssignmentObjective = Objects.requireNonNull(classAssignmentObjective); + this.downsampleFactor = downsampleFactor; + this.eta = eta; + this.etaGrowthRatePerTree = etaGrowthRatePerTree; + this.featureBagFraction = featureBagFraction; + this.maxAttemptsToAddTree = maxAttemptsToAddTree; + this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter; + this.maxTrees = maxTrees; + this.numFolds = numFolds; + this.numSplitsPerFeature = numSplitsPerFeature; + this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier; + this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier; + this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit; + this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance; + this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier; + } + + public Hyperparameters(StreamInput in) throws IOException { + this.classAssignmentObjective = in.readString(); + this.downsampleFactor = in.readDouble(); + this.eta = in.readDouble(); + this.etaGrowthRatePerTree = in.readDouble(); + this.featureBagFraction = in.readDouble(); + this.maxAttemptsToAddTree = in.readVInt(); + this.maxOptimizationRoundsPerHyperparameter = in.readVInt(); + this.maxTrees = in.readVInt(); + this.numFolds = in.readVInt(); + this.numSplitsPerFeature = in.readVInt(); + this.regularizationDepthPenaltyMultiplier = in.readDouble(); + this.regularizationLeafWeightPenaltyMultiplier = in.readDouble(); + this.regularizationSoftTreeDepthLimit = in.readDouble(); + this.regularizationSoftTreeDepthTolerance = in.readDouble(); + this.regularizationTreeSizePenaltyMultiplier = in.readDouble(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(classAssignmentObjective); + out.writeDouble(downsampleFactor); + out.writeDouble(eta); + out.writeDouble(etaGrowthRatePerTree); + out.writeDouble(featureBagFraction); + out.writeVInt(maxAttemptsToAddTree); + out.writeVInt(maxOptimizationRoundsPerHyperparameter); + out.writeVInt(maxTrees); + out.writeVInt(numFolds); + out.writeVInt(numSplitsPerFeature); + out.writeDouble(regularizationDepthPenaltyMultiplier); + out.writeDouble(regularizationLeafWeightPenaltyMultiplier); + out.writeDouble(regularizationSoftTreeDepthLimit); + out.writeDouble(regularizationSoftTreeDepthTolerance); + out.writeDouble(regularizationTreeSizePenaltyMultiplier); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); + builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor); + builder.field(ETA.getPreferredName(), eta); + builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree); + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree); + builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter); + builder.field(MAX_TREES.getPreferredName(), maxTrees); + builder.field(NUM_FOLDS.getPreferredName(), numFolds); + builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature); + builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier); + builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier); + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit); + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance); + builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Hyperparameters that = (Hyperparameters) o; + return Objects.equals(classAssignmentObjective, that.classAssignmentObjective) + && downsampleFactor == that.downsampleFactor + && eta == that.eta + && etaGrowthRatePerTree == that.etaGrowthRatePerTree + && featureBagFraction == that.featureBagFraction + && maxAttemptsToAddTree == that.maxAttemptsToAddTree + && maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter + && maxTrees == that.maxTrees + && numFolds == that.numFolds + && numSplitsPerFeature == that.numSplitsPerFeature + && regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier + && regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier + && regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit + && regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance + && regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier; + } + + @Override + public int hashCode() { + return Objects.hash( + classAssignmentObjective, + downsampleFactor, + eta, + etaGrowthRatePerTree, + featureBagFraction, + maxAttemptsToAddTree, + maxOptimizationRoundsPerHyperparameter, + maxTrees, + numFolds, + numSplitsPerFeature, + regularizationDepthPenaltyMultiplier, + regularizationLeafWeightPenaltyMultiplier, + regularizationSoftTreeDepthLimit, + regularizationSoftTreeDepthTolerance, + regularizationTreeSizePenaltyMultiplier + ); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/TimingStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/TimingStats.java new file mode 100644 index 00000000000..07245c88f20 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/TimingStats.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class TimingStats implements Writeable, ToXContentObject { + + public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time"); + public static final ParseField ITERATION_TIME = new ParseField("iteration_time"); + + public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("classification_timing_stats", + ignoreUnknownFields, + a -> new TimingStats(TimeValue.timeValueMillis((long) a[0]), TimeValue.timeValueMillis((long) a[1]))); + + parser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME); + parser.declareLong(ConstructingObjectParser.constructorArg(), ITERATION_TIME); + return parser; + } + + private final TimeValue elapsedTime; + private final TimeValue iterationTime; + + public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) { + this.elapsedTime = Objects.requireNonNull(elapsedTime); + this.iterationTime = Objects.requireNonNull(iterationTime); + } + + public TimingStats(StreamInput in) throws IOException { + this.elapsedTime = in.readTimeValue(); + this.iterationTime = in.readTimeValue(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeTimeValue(elapsedTime); + out.writeTimeValue(iterationTime); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime); + builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimingStats that = (TimingStats) o; + return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime); + } + + @Override + public int hashCode() { + return Objects.hash(elapsedTime, iterationTime); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ValidationLoss.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ValidationLoss.java new file mode 100644 index 00000000000..5526ae063a4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ValidationLoss.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValues; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class ValidationLoss implements ToXContentObject, Writeable { + + public static final ParseField LOSS_TYPE = new ParseField("loss_type"); + public static final ParseField FOLD_VALUES = new ParseField("fold_values"); + + public static ValidationLoss fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("classification_validation_loss", + ignoreUnknownFields, + a -> new ValidationLoss((String) a[0], (List) a[1])); + + parser.declareString(ConstructingObjectParser.constructorArg(), LOSS_TYPE); + parser.declareObjectArray(ConstructingObjectParser.constructorArg(), + (p, c) -> FoldValues.fromXContent(p, ignoreUnknownFields), FOLD_VALUES); + return parser; + } + + private final String lossType; + private final List foldValues; + + public ValidationLoss(String lossType, List values) { + this.lossType = Objects.requireNonNull(lossType); + this.foldValues = Objects.requireNonNull(values); + } + + public ValidationLoss(StreamInput in) throws IOException { + lossType = in.readString(); + foldValues = in.readList(FoldValues::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(lossType); + out.writeList(foldValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(LOSS_TYPE.getPreferredName(), lossType); + builder.field(FOLD_VALUES.getPreferredName(), foldValues); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ValidationLoss that = (ValidationLoss) o; + return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues); + } + + @Override + public int hashCode() { + return Objects.hash(lossType, foldValues); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/FoldValues.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/FoldValues.java new file mode 100644 index 00000000000..1638f18fd49 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/FoldValues.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class FoldValues implements Writeable, ToXContentObject { + + public static final ParseField FOLD = new ParseField("fold"); + public static final ParseField VALUES = new ParseField("values"); + + public static FoldValues fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("fold_values", ignoreUnknownFields, + a -> new FoldValues((int) a[0], (List) a[1])); + parser.declareInt(ConstructingObjectParser.constructorArg(), FOLD); + parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), VALUES); + return parser; + } + + private final int fold; + private final double[] values; + + private FoldValues(int fold, List values) { + this(fold, values.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public FoldValues(int fold, double[] values) { + this.fold = fold; + this.values = values; + } + + public FoldValues(StreamInput in) throws IOException { + fold = in.readVInt(); + values = in.readDoubleArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(fold); + out.writeDoubleArray(values); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FOLD.getPreferredName(), fold); + builder.array(VALUES.getPreferredName(), values); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + FoldValues other = (FoldValues) o; + return fold == other.fold && Arrays.equals(values, other.values); + } + + @Override + public int hashCode() { + return Objects.hash(fold, Arrays.hashCode(values)); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/OutlierDetectionStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/OutlierDetectionStats.java new file mode 100644 index 00000000000..e1e24f36b87 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/OutlierDetectionStats.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.time.Instant; +import java.util.Objects; + +public class OutlierDetectionStats implements AnalysisStats { + + public static final String TYPE_VALUE = "outlier_detection_stats"; + + public static final ParseField PARAMETERS = new ParseField("parameters"); + public static final ParseField TIMING_STATS = new ParseField("timings_stats"); + + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, + a -> new OutlierDetectionStats((String) a[0], (Instant) a[1], (Parameters) a[2], (TimingStats) a[3])); + + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); + parser.declareField(ConstructingObjectParser.constructorArg(), + p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()), + Fields.TIMESTAMP, + ObjectParser.ValueType.VALUE); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> Parameters.fromXContent(p, ignoreUnknownFields), PARAMETERS); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS); + return parser; + } + + private final String jobId; + private final Instant timestamp; + private final Parameters parameters; + private final TimingStats timingStats; + + public OutlierDetectionStats(String jobId, Instant timestamp, Parameters parameters, TimingStats timingStats) { + this.jobId = Objects.requireNonNull(jobId); + // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure + // internal representation matches toXContent + this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli()); + this.parameters = Objects.requireNonNull(parameters); + this.timingStats = Objects.requireNonNull(timingStats); + } + + public OutlierDetectionStats(StreamInput in) throws IOException { + this.jobId = in.readString(); + this.timestamp = in.readInstant(); + this.parameters = new Parameters(in); + this.timingStats = new TimingStats(in); + } + + @Override + public String getWriteableName() { + return TYPE_VALUE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(jobId); + out.writeInstant(timestamp); + parameters.writeTo(out); + timingStats.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); + } + builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + builder.field(PARAMETERS.getPreferredName(), parameters); + builder.field(TIMING_STATS.getPreferredName(), timingStats); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OutlierDetectionStats that = (OutlierDetectionStats) o; + return Objects.equals(jobId, that.jobId) + && Objects.equals(timestamp, that.timestamp) + && Objects.equals(parameters, that.parameters) + && Objects.equals(timingStats, that.timingStats); + } + + @Override + public int hashCode() { + return Objects.hash(jobId, timestamp, parameters, timingStats); + } + + public String documentId(String jobId) { + return documentIdPrefix(jobId) + timestamp.toEpochMilli(); + } + + public static String documentIdPrefix(String jobId) { + return TYPE_VALUE + "_" + jobId + "_"; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/Parameters.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/Parameters.java new file mode 100644 index 00000000000..79c74e457b3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/Parameters.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class Parameters implements Writeable, ToXContentObject { + + public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); + public static final ParseField METHOD = new ParseField("method"); + public static final ParseField FEATURE_INFLUENCE_THRESHOLD = new ParseField("feature_influence_threshold"); + public static final ParseField COMPUTE_FEATURE_INFLUENCE = new ParseField("compute_feature_influence"); + public static final ParseField OUTLIER_FRACTION = new ParseField("outlier_fraction"); + public static final ParseField STANDARDIZATION_ENABLED = new ParseField("standardization_enabled"); + + public static Parameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("outlier_detection_parameters", + ignoreUnknownFields, + a -> new Parameters( + (int) a[0], + (String) a[1], + (boolean) a[2], + (double) a[3], + (double) a[4], + (boolean) a[5] + )); + + parser.declareInt(constructorArg(), N_NEIGHBORS); + parser.declareString(constructorArg(), METHOD); + parser.declareBoolean(constructorArg(), COMPUTE_FEATURE_INFLUENCE); + parser.declareDouble(constructorArg(), FEATURE_INFLUENCE_THRESHOLD); + parser.declareDouble(constructorArg(), OUTLIER_FRACTION); + parser.declareBoolean(constructorArg(), STANDARDIZATION_ENABLED); + + return parser; + } + + private final int nNeighbors; + private final String method; + private final boolean computeFeatureInfluence; + private final double featureInfluenceThreshold; + private final double outlierFraction; + private final boolean standardizationEnabled; + + public Parameters(int nNeighbors, String method, boolean computeFeatureInfluence, double featureInfluenceThreshold, + double outlierFraction, boolean standardizationEnabled) { + this.nNeighbors = nNeighbors; + this.method = method; + this.computeFeatureInfluence = computeFeatureInfluence; + this.featureInfluenceThreshold = featureInfluenceThreshold; + this.outlierFraction = outlierFraction; + this.standardizationEnabled = standardizationEnabled; + } + + public Parameters(StreamInput in) throws IOException { + this.nNeighbors = in.readVInt(); + this.method = in.readString(); + this.computeFeatureInfluence = in.readBoolean(); + this.featureInfluenceThreshold = in.readDouble(); + this.outlierFraction = in.readDouble(); + this.standardizationEnabled = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(nNeighbors); + out.writeString(method); + out.writeBoolean(computeFeatureInfluence); + out.writeDouble(featureInfluenceThreshold); + out.writeDouble(outlierFraction); + out.writeBoolean(standardizationEnabled); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); + builder.field(METHOD.getPreferredName(), method); + builder.field(COMPUTE_FEATURE_INFLUENCE.getPreferredName(), computeFeatureInfluence); + builder.field(FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), featureInfluenceThreshold); + builder.field(OUTLIER_FRACTION.getPreferredName(), outlierFraction); + builder.field(STANDARDIZATION_ENABLED.getPreferredName(), standardizationEnabled); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Parameters that = (Parameters) o; + return nNeighbors == that.nNeighbors + && Objects.equals(method, that.method) + && computeFeatureInfluence == that.computeFeatureInfluence + && featureInfluenceThreshold == that.featureInfluenceThreshold + && outlierFraction == that.outlierFraction + && standardizationEnabled == that.standardizationEnabled; + } + + @Override + public int hashCode() { + return Objects.hash(nNeighbors, method, computeFeatureInfluence, featureInfluenceThreshold, outlierFraction, + standardizationEnabled); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/TimingStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/TimingStats.java new file mode 100644 index 00000000000..7721cd3fb67 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/TimingStats.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class TimingStats implements Writeable, ToXContentObject { + + public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time"); + + public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("outlier_detection_timing_stats", + ignoreUnknownFields, + a -> new TimingStats(TimeValue.timeValueMillis((long) a[0]))); + + parser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME); + return parser; + } + + private final TimeValue elapsedTime; + + public TimingStats(TimeValue elapsedTime) { + this.elapsedTime = Objects.requireNonNull(elapsedTime); + } + + public TimingStats(StreamInput in) throws IOException { + this.elapsedTime = in.readTimeValue(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeTimeValue(elapsedTime); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimingStats that = (TimingStats) o; + return Objects.equals(elapsedTime, that.elapsedTime); + } + + @Override + public int hashCode() { + return Objects.hash(elapsedTime); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/Hyperparameters.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/Hyperparameters.java new file mode 100644 index 00000000000..332aa32a48b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/Hyperparameters.java @@ -0,0 +1,226 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class Hyperparameters implements ToXContentObject, Writeable { + + public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree"); + public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField( + "max_optimization_rounds_per_hyperparameter"); + public static final ParseField MAX_TREES = new ParseField("max_trees"); + public static final ParseField NUM_FOLDS = new ParseField("num_folds"); + public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature"); + public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier"); + public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER + = new ParseField("regularization_leaf_weight_penalty_multiplier"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit"); + public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance"); + public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER = + new ParseField("regularization_tree_size_penalty_multiplier"); + + public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("regression_hyperparameters", + ignoreUnknownFields, + a -> new Hyperparameters( + (double) a[0], + (double) a[1], + (double) a[2], + (double) a[3], + (int) a[4], + (int) a[5], + (int) a[6], + (int) a[7], + (int) a[8], + (double) a[9], + (double) a[10], + (double) a[11], + (double) a[12], + (double) a[13] + )); + + parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR); + parser.declareDouble(constructorArg(), ETA); + parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE); + parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION); + parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE); + parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER); + parser.declareInt(constructorArg(), MAX_TREES); + parser.declareInt(constructorArg(), NUM_FOLDS); + parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE); + parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER); + parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER); + parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT); + parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE); + parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER); + + return parser; + } + + private final double downsampleFactor; + private final double eta; + private final double etaGrowthRatePerTree; + private final double featureBagFraction; + private final int maxAttemptsToAddTree; + private final int maxOptimizationRoundsPerHyperparameter; + private final int maxTrees; + private final int numFolds; + private final int numSplitsPerFeature; + private final double regularizationDepthPenaltyMultiplier; + private final double regularizationLeafWeightPenaltyMultiplier; + private final double regularizationSoftTreeDepthLimit; + private final double regularizationSoftTreeDepthTolerance; + private final double regularizationTreeSizePenaltyMultiplier; + + public Hyperparameters(double downsampleFactor, + double eta, + double etaGrowthRatePerTree, + double featureBagFraction, + int maxAttemptsToAddTree, + int maxOptimizationRoundsPerHyperparameter, + int maxTrees, + int numFolds, + int numSplitsPerFeature, + double regularizationDepthPenaltyMultiplier, + double regularizationLeafWeightPenaltyMultiplier, + double regularizationSoftTreeDepthLimit, + double regularizationSoftTreeDepthTolerance, + double regularizationTreeSizePenaltyMultiplier) { + this.downsampleFactor = downsampleFactor; + this.eta = eta; + this.etaGrowthRatePerTree = etaGrowthRatePerTree; + this.featureBagFraction = featureBagFraction; + this.maxAttemptsToAddTree = maxAttemptsToAddTree; + this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter; + this.maxTrees = maxTrees; + this.numFolds = numFolds; + this.numSplitsPerFeature = numSplitsPerFeature; + this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier; + this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier; + this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit; + this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance; + this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier; + } + + public Hyperparameters(StreamInput in) throws IOException { + this.downsampleFactor = in.readDouble(); + this.eta = in.readDouble(); + this.etaGrowthRatePerTree = in.readDouble(); + this.featureBagFraction = in.readDouble(); + this.maxAttemptsToAddTree = in.readVInt(); + this.maxOptimizationRoundsPerHyperparameter = in.readVInt(); + this.maxTrees = in.readVInt(); + this.numFolds = in.readVInt(); + this.numSplitsPerFeature = in.readVInt(); + this.regularizationDepthPenaltyMultiplier = in.readDouble(); + this.regularizationLeafWeightPenaltyMultiplier = in.readDouble(); + this.regularizationSoftTreeDepthLimit = in.readDouble(); + this.regularizationSoftTreeDepthTolerance = in.readDouble(); + this.regularizationTreeSizePenaltyMultiplier = in.readDouble(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(downsampleFactor); + out.writeDouble(eta); + out.writeDouble(etaGrowthRatePerTree); + out.writeDouble(featureBagFraction); + out.writeVInt(maxAttemptsToAddTree); + out.writeVInt(maxOptimizationRoundsPerHyperparameter); + out.writeVInt(maxTrees); + out.writeVInt(numFolds); + out.writeVInt(numSplitsPerFeature); + out.writeDouble(regularizationDepthPenaltyMultiplier); + out.writeDouble(regularizationLeafWeightPenaltyMultiplier); + out.writeDouble(regularizationSoftTreeDepthLimit); + out.writeDouble(regularizationSoftTreeDepthTolerance); + out.writeDouble(regularizationTreeSizePenaltyMultiplier); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor); + builder.field(ETA.getPreferredName(), eta); + builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree); + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree); + builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter); + builder.field(MAX_TREES.getPreferredName(), maxTrees); + builder.field(NUM_FOLDS.getPreferredName(), numFolds); + builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature); + builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier); + builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier); + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit); + builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance); + builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Hyperparameters that = (Hyperparameters) o; + return downsampleFactor == that.downsampleFactor + && eta == that.eta + && etaGrowthRatePerTree == that.etaGrowthRatePerTree + && featureBagFraction == that.featureBagFraction + && maxAttemptsToAddTree == that.maxAttemptsToAddTree + && maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter + && maxTrees == that.maxTrees + && numFolds == that.numFolds + && numSplitsPerFeature == that.numSplitsPerFeature + && regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier + && regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier + && regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit + && regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance + && regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier; + } + + @Override + public int hashCode() { + return Objects.hash( + downsampleFactor, + eta, + etaGrowthRatePerTree, + featureBagFraction, + maxAttemptsToAddTree, + maxOptimizationRoundsPerHyperparameter, + maxTrees, + numFolds, + numSplitsPerFeature, + regularizationDepthPenaltyMultiplier, + regularizationLeafWeightPenaltyMultiplier, + regularizationSoftTreeDepthLimit, + regularizationSoftTreeDepthTolerance, + regularizationTreeSizePenaltyMultiplier + ); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/RegressionStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/RegressionStats.java new file mode 100644 index 00000000000..c6afc7d4f37 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/RegressionStats.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.time.Instant; +import java.util.Objects; + +public class RegressionStats implements AnalysisStats { + + public static final String TYPE_VALUE = "regression_stats"; + + public static final ParseField ITERATION = new ParseField("iteration"); + public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters"); + public static final ParseField TIMING_STATS = new ParseField("timing_stats"); + public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss"); + + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, + a -> new RegressionStats( + (String) a[0], + (Instant) a[1], + (int) a[2], + (Hyperparameters) a[3], + (TimingStats) a[4], + (ValidationLoss) a[5] + ) + ); + + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); + parser.declareField(ConstructingObjectParser.constructorArg(), + p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()), + Fields.TIMESTAMP, + ObjectParser.ValueType.VALUE); + parser.declareInt(ConstructingObjectParser.constructorArg(), ITERATION); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> Hyperparameters.fromXContent(p, ignoreUnknownFields), HYPERPARAMETERS); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS); + parser.declareObject(ConstructingObjectParser.constructorArg(), + (p, c) -> ValidationLoss.fromXContent(p, ignoreUnknownFields), VALIDATION_LOSS); + return parser; + } + + private final String jobId; + private final Instant timestamp; + private final int iteration; + private final Hyperparameters hyperparameters; + private final TimingStats timingStats; + private final ValidationLoss validationLoss; + + public RegressionStats(String jobId, Instant timestamp, int iteration, Hyperparameters hyperparameters, TimingStats timingStats, + ValidationLoss validationLoss) { + this.jobId = Objects.requireNonNull(jobId); + // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure + // internal representation matches toXContent + this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli()); + this.iteration = iteration; + this.hyperparameters = Objects.requireNonNull(hyperparameters); + this.timingStats = Objects.requireNonNull(timingStats); + this.validationLoss = Objects.requireNonNull(validationLoss); + } + + public RegressionStats(StreamInput in) throws IOException { + this.jobId = in.readString(); + this.timestamp = in.readInstant(); + this.iteration = in.readVInt(); + this.hyperparameters = new Hyperparameters(in); + this.timingStats = new TimingStats(in); + this.validationLoss = new ValidationLoss(in); + } + + @Override + public String getWriteableName() { + return TYPE_VALUE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(jobId); + out.writeInstant(timestamp); + out.writeVInt(iteration); + hyperparameters.writeTo(out); + timingStats.writeTo(out); + validationLoss.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); + } + builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli()); + builder.field(ITERATION.getPreferredName(), iteration); + builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters); + builder.field(TIMING_STATS.getPreferredName(), timingStats); + builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RegressionStats that = (RegressionStats) o; + return Objects.equals(jobId, that.jobId) + && Objects.equals(timestamp, that.timestamp) + && iteration == that.iteration + && Objects.equals(hyperparameters, that.hyperparameters) + && Objects.equals(timingStats, that.timingStats) + && Objects.equals(validationLoss, that.validationLoss); + } + + @Override + public int hashCode() { + return Objects.hash(jobId, timestamp, iteration, hyperparameters, timingStats, validationLoss); + } + + public String documentId(String jobId) { + return documentIdPrefix(jobId) + timestamp.toEpochMilli(); + } + + public static String documentIdPrefix(String jobId) { + return TYPE_VALUE + "_" + jobId + "_"; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/TimingStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/TimingStats.java new file mode 100644 index 00000000000..8dba89f601d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/TimingStats.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class TimingStats implements Writeable, ToXContentObject { + + public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time"); + public static final ParseField ITERATION_TIME = new ParseField("iteration_time"); + + public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("regression_timing_stats", ignoreUnknownFields, + a -> new TimingStats(TimeValue.timeValueMillis((long) a[0]), TimeValue.timeValueMillis((long) a[1]))); + + parser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME); + parser.declareLong(ConstructingObjectParser.constructorArg(), ITERATION_TIME); + return parser; + } + + private final TimeValue elapsedTime; + private final TimeValue iterationTime; + + public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) { + this.elapsedTime = Objects.requireNonNull(elapsedTime); + this.iterationTime = Objects.requireNonNull(iterationTime); + } + + public TimingStats(StreamInput in) throws IOException { + this.elapsedTime = in.readTimeValue(); + this.iterationTime = in.readTimeValue(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeTimeValue(elapsedTime); + out.writeTimeValue(iterationTime); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime); + builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimingStats that = (TimingStats) o; + return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime); + } + + @Override + public int hashCode() { + return Objects.hash(elapsedTime, iterationTime); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/ValidationLoss.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/ValidationLoss.java new file mode 100644 index 00000000000..c31eb72bce0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/ValidationLoss.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValues; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class ValidationLoss implements ToXContentObject, Writeable { + + public static final ParseField LOSS_TYPE = new ParseField("loss_type"); + public static final ParseField FOLD_VALUES = new ParseField("fold_values"); + + public static ValidationLoss fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return createParser(ignoreUnknownFields).apply(parser, null); + } + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("regression_validation_loss", + ignoreUnknownFields, + a -> new ValidationLoss((String) a[0], (List) a[1])); + + parser.declareString(ConstructingObjectParser.constructorArg(), LOSS_TYPE); + parser.declareObjectArray(ConstructingObjectParser.constructorArg(), + (p, c) -> FoldValues.fromXContent(p, ignoreUnknownFields), FOLD_VALUES); + return parser; + } + + private final String lossType; + private final List foldValues; + + public ValidationLoss(String lossType, List values) { + this.lossType = Objects.requireNonNull(lossType); + this.foldValues = Objects.requireNonNull(values); + } + + public ValidationLoss(StreamInput in) throws IOException { + lossType = in.readString(); + foldValues = in.readList(FoldValues::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(lossType); + out.writeList(foldValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(LOSS_TYPE.getPreferredName(), lossType); + builder.field(FOLD_VALUES.getPreferredName(), foldValues); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ValidationLoss that = (ValidationLoss) o; + return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues); + } + + @Override + public int hashCode() { + return Objects.hash(lossType, foldValues); + } +} diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json index 838ea90d947..f7f5e1e4d20 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json @@ -3,18 +3,120 @@ "_meta": { "version" : "${xpack.ml.version}" }, + "dynamic": false, "properties" : { - "type" : { - "type" : "keyword" + "iteration": { + "type": "integer" + }, + "hyperparameters": { + "properties": { + "class_assignment_objective": { + "type": "keyword" + }, + "downsample_factor": { + "type": "double" + }, + "eta": { + "type": "double" + }, + "eta_growth_rate_per_tree": { + "type": "double" + }, + "feature_bag_fraction": { + "type": "double" + }, + "max_attempts_to_add_tree": { + "type": "integer" + }, + "max_optimization_rounds_per_hyperparameter": { + "type": "integer" + }, + "max_trees": { + "type": "integer" + }, + "num_folds": { + "type": "integer" + }, + "num_splits_per_feature": { + "type": "integer" + }, + "regularization_depth_penalty_multiplier": { + "type": "double" + }, + "regularization_leaf_weight_penalty_multiplier": { + "type": "double" + }, + "regularization_soft_tree_depth_limit": { + "type": "double" + }, + "regularization_soft_tree_depth_tolerance": { + "type": "double" + }, + "regularization_tree_size_penalty_multiplier": { + "type": "double" + } + } }, "job_id" : { "type" : "keyword" }, - "timestamp" : { - "type" : "date" + "parameters": { + "properties": { + "compute_feature_influence": { + "type": "boolean" + }, + "feature_influence_threshold": { + "type": "double" + }, + "method": { + "type": "keyword" + }, + "n_neighbors": { + "type": "integer" + }, + "outlier_fraction": { + "type": "double" + }, + "standardization_enabled": { + "type": "boolean" + } + } }, "peak_usage_bytes" : { "type" : "long" + }, + "timestamp" : { + "type" : "date" + }, + "timing_stats": { + "properties": { + "elapsed_time": { + "type": "long" + }, + "iteration_time": { + "type": "long" + } + } + }, + "type" : { + "type" : "keyword" + }, + "validation_loss": { + "properties": { + "fold_values": { + "properties": { + "fold": { + "type": "integer" + }, + "values": { + "type": "double" + } + } + }, + "loss_type": { + "type": "keyword" + } + } } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index f5dab116b38..5cb2b3fef54 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -5,14 +5,20 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.util.ArrayList; @@ -21,6 +27,13 @@ import java.util.stream.IntStream; public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase { + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); + } + public static Response randomResponse(int listSize) { List analytics = new ArrayList<>(listSize); for (int j = 0; j < listSize; j++) { @@ -30,8 +43,15 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS IntStream.of(progressSize).forEach(progressIndex -> progress.add( new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)))); MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom(); + AnalysisStats analysisStats = randomBoolean() ? null : + randomFrom( + ClassificationStatsTests.createRandom(), + OutlierDetectionStatsTests.createRandom(), + RegressionStatsTests.createRandom() + ); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, null, randomAlphaOfLength(20)); + randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null, + randomAlphaOfLength(20)); analytics.add(stats); } return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ClassificationStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ClassificationStatsTests.java new file mode 100644 index 00000000000..4aba8d77fff --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ClassificationStatsTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; + +public class ClassificationStatsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected ClassificationStats mutateInstanceForVersion(ClassificationStats instance, Version version) { + return instance; + } + + @Override + protected ClassificationStats doParseInstance(XContentParser parser) throws IOException { + return lenient ? ClassificationStats.LENIENT_PARSER.apply(parser, null) : ClassificationStats.STRICT_PARSER.apply(parser, null); + } + + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationStats::new; + } + + @Override + protected ClassificationStats createTestInstance() { + return createRandom(); + } + + public static ClassificationStats createRandom() { + return new ClassificationStats( + randomAlphaOfLength(10), + Instant.now(), + randomIntBetween(1, Integer.MAX_VALUE), + HyperparametersTests.createRandom(), + TimingStatsTests.createRandom(), + ValidationLossTests.createRandom() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/HyperparametersTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/HyperparametersTests.java new file mode 100644 index 00000000000..4c67e3c36be --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/HyperparametersTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class HyperparametersTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) { + return instance; + } + + @Override + protected Hyperparameters doParseInstance(XContentParser parser) throws IOException { + return Hyperparameters.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return Hyperparameters::new; + } + + @Override + protected Hyperparameters createTestInstance() { + return createRandom(); + } + + public static Hyperparameters createRandom() { + return new Hyperparameters( + randomAlphaOfLength(10), + randomDouble(), + randomDouble(), + randomDouble(), + randomDouble(), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomDouble(), + randomDouble(), + randomDouble(), + randomDouble(), + randomDouble() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/TimingStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/TimingStatsTests.java new file mode 100644 index 00000000000..cff085e7082 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/TimingStatsTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class TimingStatsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) { + return instance; + } + + @Override + protected TimingStats doParseInstance(XContentParser parser) throws IOException { + return TimingStats.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return TimingStats::new; + } + + @Override + protected TimingStats createTestInstance() { + return createRandom(); + } + + public static TimingStats createRandom() { + return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()), TimeValue.timeValueMillis(randomNonNegativeLong())); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ValidationLossTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ValidationLossTests.java new file mode 100644 index 00000000000..2807a9e23d9 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/ValidationLossTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.classification; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValuesTests; +import org.junit.Before; + +import java.io.IOException; + +public class ValidationLossTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected ValidationLoss doParseInstance(XContentParser parser) throws IOException { + return ValidationLoss.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return ValidationLoss::new; + } + + @Override + protected ValidationLoss createTestInstance() { + return createRandom(); + } + + public static ValidationLoss createRandom() { + return new ValidationLoss( + randomAlphaOfLength(10), + randomList(5, () -> FoldValuesTests.createRandom()) + ); + } + + @Override + protected ValidationLoss mutateInstanceForVersion(ValidationLoss instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/FoldValuesTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/FoldValuesTests.java new file mode 100644 index 00000000000..92992924765 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/FoldValuesTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class FoldValuesTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected FoldValues doParseInstance(XContentParser parser) throws IOException { + return FoldValues.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return FoldValues::new; + } + + @Override + protected FoldValues createTestInstance() { + return createRandom(); + } + + public static FoldValues createRandom() { + int valuesSize = randomIntBetween(0, 10); + double[] values = new double[valuesSize]; + for (int i = 0; i < valuesSize; i++) { + values[i] = randomDouble(); + } + return new FoldValues(randomIntBetween(0, Integer.MAX_VALUE), values); + } + + @Override + protected FoldValues mutateInstanceForVersion(FoldValues instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/OutlierDetectionStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/OutlierDetectionStatsTests.java new file mode 100644 index 00000000000..40c10a8a541 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/OutlierDetectionStatsTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; + +public class OutlierDetectionStatsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected OutlierDetectionStats mutateInstanceForVersion(OutlierDetectionStats instance, Version version) { + return instance; + } + + @Override + protected OutlierDetectionStats doParseInstance(XContentParser parser) throws IOException { + return lenient ? OutlierDetectionStats.LENIENT_PARSER.apply(parser, null) + : OutlierDetectionStats.STRICT_PARSER.apply(parser, null); + } + + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected Writeable.Reader instanceReader() { + return OutlierDetectionStats::new; + } + + @Override + protected OutlierDetectionStats createTestInstance() { + return createRandom(); + } + + public static OutlierDetectionStats createRandom() { + return new OutlierDetectionStats( + randomAlphaOfLength(10), + Instant.now(), + ParametersTests.createRandom(), + TimingStatsTests.createRandom() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/ParametersTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/ParametersTests.java new file mode 100644 index 00000000000..eb84cc30ebf --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/ParametersTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class ParametersTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Parameters mutateInstanceForVersion(Parameters instance, Version version) { + return instance; + } + + @Override + protected Parameters doParseInstance(XContentParser parser) throws IOException { + return Parameters.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return Parameters::new; + } + + @Override + protected Parameters createTestInstance() { + return createRandom(); + } + + public static Parameters createRandom() { + + return new Parameters( + randomIntBetween(1, Integer.MAX_VALUE), + randomAlphaOfLength(5), + randomBoolean(), + randomDouble(), + randomDouble(), + randomBoolean() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/TimingStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/TimingStatsTests.java new file mode 100644 index 00000000000..389938a712a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/outlierdetection/TimingStatsTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class TimingStatsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) { + return instance; + } + + @Override + protected TimingStats doParseInstance(XContentParser parser) throws IOException { + return TimingStats.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return TimingStats::new; + } + + @Override + protected TimingStats createTestInstance() { + return createRandom(); + } + + public static TimingStats createRandom() { + return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong())); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/HyperparametersTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/HyperparametersTests.java new file mode 100644 index 00000000000..bebdec390ec --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/HyperparametersTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class HyperparametersTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) { + return instance; + } + + @Override + protected Hyperparameters doParseInstance(XContentParser parser) throws IOException { + return Hyperparameters.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return Hyperparameters::new; + } + + @Override + protected Hyperparameters createTestInstance() { + return createRandom(); + } + + public static Hyperparameters createRandom() { + return new Hyperparameters( + randomDouble(), + randomDouble(), + randomDouble(), + randomDouble(), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomIntBetween(0, Integer.MAX_VALUE), + randomDouble(), + randomDouble(), + randomDouble(), + randomDouble(), + randomDouble() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/RegressionStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/RegressionStatsTests.java new file mode 100644 index 00000000000..dc4cc7b7ab3 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/RegressionStatsTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; + +public class RegressionStatsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected RegressionStats mutateInstanceForVersion(RegressionStats instance, Version version) { + return instance; + } + + @Override + protected RegressionStats doParseInstance(XContentParser parser) throws IOException { + return lenient ? RegressionStats.LENIENT_PARSER.apply(parser, null) : RegressionStats.STRICT_PARSER.apply(parser, null); + } + + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionStats::new; + } + + @Override + protected RegressionStats createTestInstance() { + return createRandom(); + } + + public static RegressionStats createRandom() { + return new RegressionStats( + randomAlphaOfLength(10), + Instant.now(), + randomIntBetween(1, Integer.MAX_VALUE), + HyperparametersTests.createRandom(), + TimingStatsTests.createRandom(), + ValidationLossTests.createRandom() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/TimingStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/TimingStatsTests.java new file mode 100644 index 00000000000..ec69602d754 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/TimingStatsTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; + +public class TimingStatsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) { + return instance; + } + + @Override + protected TimingStats doParseInstance(XContentParser parser) throws IOException { + return TimingStats.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return TimingStats::new; + } + + @Override + protected TimingStats createTestInstance() { + return createRandom(); + } + + public static TimingStats createRandom() { + return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()), TimeValue.timeValueMillis(randomNonNegativeLong())); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/ValidationLossTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/ValidationLossTests.java new file mode 100644 index 00000000000..922428bff5e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/ValidationLossTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.regression; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValuesTests; +import org.junit.Before; + +import java.io.IOException; + +public class ValidationLossTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected ValidationLoss doParseInstance(XContentParser parser) throws IOException { + return ValidationLoss.fromXContent(parser, lenient); + } + + @Override + protected Writeable.Reader instanceReader() { + return ValidationLoss::new; + } + + @Override + protected ValidationLoss createTestInstance() { + return createRandom(); + } + + public static ValidationLoss createRandom() { + return new ValidationLoss( + randomAlphaOfLength(10), + randomList(5, () -> FoldValuesTests.createRandom()) + ); + } + + @Override + protected ValidationLoss mutateInstanceForVersion(ValidationLoss instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index be9e1218e2f..6efca760917 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -41,7 +41,12 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.R import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -103,7 +108,8 @@ public class TransportGetDataFrameAnalyticsStatsAction Stats stats = buildStats( task.getParams().getId(), statsHolder.getProgressTracker().report(), - statsHolder.getMemoryUsage() + statsHolder.getMemoryUsage(), + statsHolder.getAnalysisStats() ); listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); @@ -192,7 +198,10 @@ public class TransportGetDataFrameAnalyticsStatsAction MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); multiSearchRequest.add(buildStoredProgressSearch(configId)); - multiSearchRequest.add(buildMemoryUsageSearch(configId)); + multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(configId, RegressionStats.TYPE_VALUE)); executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap( multiSearchResponse -> { @@ -213,7 +222,8 @@ public class TransportGetDataFrameAnalyticsStatsAction } listener.onResponse(buildStats(configId, retrievedStatsHolder.progress.get(), - retrievedStatsHolder.memoryUsage + retrievedStatsHolder.memoryUsage, + retrievedStatsHolder.analysisStats )); }, e -> listener.onFailure(ExceptionsHelper.serverError("Error searching for stats", e)) @@ -228,18 +238,17 @@ public class TransportGetDataFrameAnalyticsStatsAction return searchRequest; } - private static SearchRequest buildMemoryUsageSearch(String configId) { + private static SearchRequest buildStatsDocSearch(String configId, String statsType) { SearchRequest searchRequest = new SearchRequest(MlStatsIndex.indexPattern()); searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); searchRequest.source().size(1); QueryBuilder query = QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery(MemoryUsage.JOB_ID.getPreferredName(), configId)) - .filter(QueryBuilders.termQuery(MemoryUsage.TYPE.getPreferredName(), MemoryUsage.TYPE_VALUE)); + .filter(QueryBuilders.termQuery(Fields.JOB_ID.getPreferredName(), configId)) + .filter(QueryBuilders.termQuery(Fields.TYPE.getPreferredName(), statsType)); searchRequest.source().query(query); - searchRequest.source().sort(SortBuilders.fieldSort(MemoryUsage.TIMESTAMP.getPreferredName()).order(SortOrder.DESC) + searchRequest.source().sort(SortBuilders.fieldSort(Fields.TIMESTAMP.getPreferredName()).order(SortOrder.DESC) // We need this for the search not to fail when there are no mappings yet in the index .unmappedType("long")); - searchRequest.source().sort(MemoryUsage.TIMESTAMP.getPreferredName(), SortOrder.DESC); return searchRequest; } @@ -249,6 +258,12 @@ public class TransportGetDataFrameAnalyticsStatsAction retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER); } else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) { retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER); + } else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) { + retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, OutlierDetectionStats.LENIENT_PARSER); + } else if (hitId.startsWith(ClassificationStats.documentIdPrefix(configId))) { + retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, ClassificationStats.LENIENT_PARSER); + } else if (hitId.startsWith(RegressionStats.documentIdPrefix(configId))) { + retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, RegressionStats.LENIENT_PARSER); } else { throw ExceptionsHelper.serverError("unexpected doc id [" + hitId + "]"); } @@ -256,7 +271,8 @@ public class TransportGetDataFrameAnalyticsStatsAction private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List progress, - MemoryUsage memoryUsage) { + MemoryUsage memoryUsage, + AnalysisStats analysisStats) { ClusterState clusterState = clusterService.state(); PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); PersistentTasksCustomMetaData.PersistentTask analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks); @@ -278,6 +294,7 @@ public class TransportGetDataFrameAnalyticsStatsAction failureReason, progress, memoryUsage, + analysisStats, node, assignmentExplanation ); @@ -287,5 +304,6 @@ public class TransportGetDataFrameAnalyticsStatsAction private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report()); private volatile MemoryUsage memoryUsage; + private volatile AnalysisStats analysisStats; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index e93c6473998..2afb2b65d5b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -23,6 +23,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; @@ -175,6 +178,21 @@ public class AnalyticsResultProcessor { statsHolder.setMemoryUsage(memoryUsage); indexStatsResult(memoryUsage, memoryUsage::documentId); } + OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats(); + if (outlierDetectionStats != null) { + statsHolder.setAnalysisStats(outlierDetectionStats); + indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId); + } + ClassificationStats classificationStats = result.getClassificationStats(); + if (classificationStats != null) { + statsHolder.setAnalysisStats(classificationStats); + indexStatsResult(classificationStats, classificationStats::documentId); + } + RegressionStats regressionStats = result.getRegressionStats(); + if (regressionStats != null) { + statsHolder.setAnalysisStats(regressionStats); + indexStatsResult(regressionStats, regressionStats::documentId); + } } private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index fcac851fa13..c3e160b7cd7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -12,6 +12,9 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import java.io.IOException; @@ -28,9 +31,20 @@ public class AnalyticsResult implements ToXContentObject { private static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); private static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage"); + private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats"); + private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); + private static final ParseField REGRESSION_STATS = new ParseField("regression_stats"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition.Builder) a[2], (MemoryUsage) a[3])); + a -> new AnalyticsResult( + (RowResults) a[0], + (Integer) a[1], + (TrainedModelDefinition.Builder) a[2], + (MemoryUsage) a[3], + (OutlierDetectionStats) a[4], + (ClassificationStats) a[5], + (RegressionStats) a[6] + )); static { PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); @@ -38,6 +52,9 @@ public class AnalyticsResult implements ToXContentObject { // TODO change back to STRICT_PARSER once native side is aligned PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE); + PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS); + PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS); + PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); } private final RowResults rowResults; @@ -45,16 +62,25 @@ public class AnalyticsResult implements ToXContentObject { private final TrainedModelDefinition.Builder inferenceModelBuilder; private final TrainedModelDefinition inferenceModel; private final MemoryUsage memoryUsage; + private final OutlierDetectionStats outlierDetectionStats; + private final ClassificationStats classificationStats; + private final RegressionStats regressionStats; public AnalyticsResult(@Nullable RowResults rowResults, @Nullable Integer progressPercent, @Nullable TrainedModelDefinition.Builder inferenceModelBuilder, - @Nullable MemoryUsage memoryUsage) { + @Nullable MemoryUsage memoryUsage, + @Nullable OutlierDetectionStats outlierDetectionStats, + @Nullable ClassificationStats classificationStats, + @Nullable RegressionStats regressionStats) { this.rowResults = rowResults; this.progressPercent = progressPercent; this.inferenceModelBuilder = inferenceModelBuilder; this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build(); this.memoryUsage = memoryUsage; + this.outlierDetectionStats = outlierDetectionStats; + this.classificationStats = classificationStats; + this.regressionStats = regressionStats; } public RowResults getRowResults() { @@ -73,6 +99,18 @@ public class AnalyticsResult implements ToXContentObject { return memoryUsage; } + public OutlierDetectionStats getOutlierDetectionStats() { + return outlierDetectionStats; + } + + public ClassificationStats getClassificationStats() { + return classificationStats; + } + + public RegressionStats getRegressionStats() { + return regressionStats; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -90,6 +128,15 @@ public class AnalyticsResult implements ToXContentObject { if (memoryUsage != null) { builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params); } + if (outlierDetectionStats != null) { + builder.field(OUTLIER_DETECTION_STATS.getPreferredName(), outlierDetectionStats, params); + } + if (classificationStats != null) { + builder.field(CLASSIFICATION_STATS.getPreferredName(), classificationStats, params); + } + if (regressionStats != null) { + builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params); + } builder.endObject(); return builder; } @@ -107,11 +154,15 @@ public class AnalyticsResult implements ToXContentObject { return Objects.equals(rowResults, that.rowResults) && Objects.equals(progressPercent, that.progressPercent) && Objects.equals(inferenceModel, that.inferenceModel) - && Objects.equals(memoryUsage, that.memoryUsage); + && Objects.equals(memoryUsage, that.memoryUsage) + && Objects.equals(outlierDetectionStats, that.outlierDetectionStats) + && Objects.equals(classificationStats, that.classificationStats) + && Objects.equals(regressionStats, that.regressionStats); } @Override public int hashCode() { - return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage); + return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats, + regressionStats); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index d2e9bdd957e..ff6b9ec7bcf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.ml.dataframe.stats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import java.util.concurrent.atomic.AtomicReference; @@ -17,10 +18,12 @@ public class StatsHolder { private final ProgressTracker progressTracker; private final AtomicReference memoryUsageHolder; + private final AtomicReference analysisStatsHolder; public StatsHolder() { progressTracker = new ProgressTracker(); memoryUsageHolder = new AtomicReference<>(); + analysisStatsHolder = new AtomicReference<>(); } public ProgressTracker getProgressTracker() { @@ -34,4 +37,12 @@ public class StatsHolder { public MemoryUsage getMemoryUsage() { return memoryUsageHolder.get(); } + + public void setAnalysisStats(AnalysisStats analysisStats) { + analysisStatsHolder.set(analysisStats); + } + + public AnalysisStats getAnalysisStats() { + return analysisStatsHolder.get(); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index ee3436ff87f..008e97571f7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -56,7 +56,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { private static final String CONFIG_ID = "config-id"; private static final int NUM_ROWS = 100; private static final int NUM_COLS = 4; - private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null); + private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null); private Client client; private DataFrameAnalyticsAuditor auditor; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 23051163439..da7bbff71d5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -100,7 +100,9 @@ public class AnalyticsResultProcessorTests extends ESTestCase { public void testProcess_GivenEmptyResults() { givenDataFrameRows(2); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50, null, null), new AnalyticsResult(null, 100, null, null))); + givenProcessResults(Arrays.asList( + new AnalyticsResult(null, 50, null, null, null, null, null), + new AnalyticsResult(null, 100, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -115,8 +117,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null), - new AnalyticsResult(rowResults2, 100, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null), + new AnalyticsResult(rowResults2, 100, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -133,8 +135,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null), - new AnalyticsResult(rowResults2, 100, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null), + new AnalyticsResult(rowResults2, 100, null, null, null, null, null))); doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); @@ -167,7 +169,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet()))); extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); resultProcessor.process(process); @@ -212,7 +214,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index e4332dc1767..d1be7cbd250 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -11,7 +11,14 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; @@ -36,6 +43,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase