mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-24 22:09:24 +00:00
Adds parsing and indexing of analysis instrumentation stats. The latest one is also returned from the get-stats API. Note that we chose to duplicate objects even where they are currently similar. There are already ideas on how these will diverge in the future and while the duplication looks ugly at the moment, it is the option that offers the highest flexibility. Backport of #53788
This commit is contained in:
parent
f7143b8d85
commit
60153c5433
@ -20,12 +20,15 @@
|
||||
package org.elasticsearch.client.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.client.ml.NodeAttributes;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.inject.internal.ToStringBuilder;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParserUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
@ -45,6 +48,7 @@ public class DataFrameAnalyticsStats {
|
||||
static final ParseField FAILURE_REASON = new ParseField("failure_reason");
|
||||
static final ParseField PROGRESS = new ParseField("progress");
|
||||
static final ParseField MEMORY_USAGE = new ParseField("memory_usage");
|
||||
static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats");
|
||||
static final ParseField NODE = new ParseField("node");
|
||||
static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation");
|
||||
|
||||
@ -57,8 +61,9 @@ public class DataFrameAnalyticsStats {
|
||||
(String) args[2],
|
||||
(List<PhaseProgress>) args[3],
|
||||
(MemoryUsage) args[4],
|
||||
(NodeAttributes) args[5],
|
||||
(String) args[6]));
|
||||
(AnalysisStats) args[5],
|
||||
(NodeAttributes) args[6],
|
||||
(String) args[7]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ID);
|
||||
@ -71,26 +76,38 @@ public class DataFrameAnalyticsStats {
|
||||
PARSER.declareString(optionalConstructorArg(), FAILURE_REASON);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS);
|
||||
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE);
|
||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS);
|
||||
PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE);
|
||||
PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION);
|
||||
}
|
||||
|
||||
private static AnalysisStats parseAnalysisStats(XContentParser parser) throws IOException {
|
||||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
|
||||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
|
||||
AnalysisStats analysisStats = parser.namedObject(AnalysisStats.class, parser.currentName(), true);
|
||||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
|
||||
return analysisStats;
|
||||
}
|
||||
|
||||
private final String id;
|
||||
private final DataFrameAnalyticsState state;
|
||||
private final String failureReason;
|
||||
private final List<PhaseProgress> progress;
|
||||
private final MemoryUsage memoryUsage;
|
||||
private final AnalysisStats analysisStats;
|
||||
private final NodeAttributes node;
|
||||
private final String assignmentExplanation;
|
||||
|
||||
public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason,
|
||||
@Nullable List<PhaseProgress> progress, @Nullable MemoryUsage memoryUsage,
|
||||
@Nullable NodeAttributes node, @Nullable String assignmentExplanation) {
|
||||
@Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node,
|
||||
@Nullable String assignmentExplanation) {
|
||||
this.id = id;
|
||||
this.state = state;
|
||||
this.failureReason = failureReason;
|
||||
this.progress = progress;
|
||||
this.memoryUsage = memoryUsage;
|
||||
this.analysisStats = analysisStats;
|
||||
this.node = node;
|
||||
this.assignmentExplanation = assignmentExplanation;
|
||||
}
|
||||
@ -116,6 +133,11 @@ public class DataFrameAnalyticsStats {
|
||||
return memoryUsage;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public AnalysisStats getAnalysisStats() {
|
||||
return analysisStats;
|
||||
}
|
||||
|
||||
public NodeAttributes getNode() {
|
||||
return node;
|
||||
}
|
||||
@ -135,13 +157,14 @@ public class DataFrameAnalyticsStats {
|
||||
&& Objects.equals(failureReason, other.failureReason)
|
||||
&& Objects.equals(progress, other.progress)
|
||||
&& Objects.equals(memoryUsage, other.memoryUsage)
|
||||
&& Objects.equals(analysisStats, other.analysisStats)
|
||||
&& Objects.equals(node, other.node)
|
||||
&& Objects.equals(assignmentExplanation, other.assignmentExplanation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(id, state, failureReason, progress, memoryUsage, node, assignmentExplanation);
|
||||
return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -152,6 +175,7 @@ public class DataFrameAnalyticsStats {
|
||||
.add("failureReason", failureReason)
|
||||
.add("progress", progress)
|
||||
.add("memoryUsage", memoryUsage)
|
||||
.add("analysisStats", analysisStats)
|
||||
.add("node", node)
|
||||
.add("assignmentExplanation", assignmentExplanation)
|
||||
.toString();
|
||||
|
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats;
|
||||
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
||||
/**
|
||||
* Statistics for the data frame analysis
|
||||
*/
|
||||
public interface AnalysisStats extends ToXContentObject {
|
||||
|
||||
String getName();
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class AnalysisStatsNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
return Arrays.asList(
|
||||
new NamedXContentRegistry.Entry(
|
||||
AnalysisStats.class,
|
||||
ClassificationStats.NAME,
|
||||
(p, c) -> ClassificationStats.PARSER.apply(p, null)
|
||||
),
|
||||
new NamedXContentRegistry.Entry(
|
||||
AnalysisStats.class,
|
||||
OutlierDetectionStats.NAME,
|
||||
(p, c) -> OutlierDetectionStats.PARSER.apply(p, null)
|
||||
),
|
||||
new NamedXContentRegistry.Entry(
|
||||
AnalysisStats.class,
|
||||
RegressionStats.NAME,
|
||||
(p, c) -> RegressionStats.PARSER.apply(p, null)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,135 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.client.common.TimeUtil;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ClassificationStats implements AnalysisStats {
|
||||
|
||||
public static final ParseField NAME = new ParseField("classification_stats");
|
||||
|
||||
public static final ParseField TIMESTAMP = new ParseField("timestamp");
|
||||
public static final ParseField ITERATION = new ParseField("iteration");
|
||||
public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
|
||||
public static final ParseField TIMING_STATS = new ParseField("timing_stats");
|
||||
public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");
|
||||
|
||||
public static final ConstructingObjectParser<ClassificationStats, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
|
||||
true,
|
||||
a -> new ClassificationStats(
|
||||
(Instant) a[0],
|
||||
(Integer) a[1],
|
||||
(Hyperparameters) a[2],
|
||||
(TimingStats) a[3],
|
||||
(ValidationLoss) a[4]
|
||||
)
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
|
||||
TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), ITERATION);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Hyperparameters.PARSER, HYPERPARAMETERS);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), ValidationLoss.PARSER, VALIDATION_LOSS);
|
||||
}
|
||||
|
||||
private final Instant timestamp;
|
||||
private final Integer iteration;
|
||||
private final Hyperparameters hyperparameters;
|
||||
private final TimingStats timingStats;
|
||||
private final ValidationLoss validationLoss;
|
||||
|
||||
public ClassificationStats(Instant timestamp, Integer iteration, Hyperparameters hyperparameters, TimingStats timingStats,
|
||||
ValidationLoss validationLoss) {
|
||||
this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
|
||||
this.iteration = iteration;
|
||||
this.hyperparameters = Objects.requireNonNull(hyperparameters);
|
||||
this.timingStats = Objects.requireNonNull(timingStats);
|
||||
this.validationLoss = Objects.requireNonNull(validationLoss);
|
||||
}
|
||||
|
||||
public Instant getTimestamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
public Integer getIteration() {
|
||||
return iteration;
|
||||
}
|
||||
|
||||
public Hyperparameters getHyperparameters() {
|
||||
return hyperparameters;
|
||||
}
|
||||
|
||||
public TimingStats getTimingStats() {
|
||||
return timingStats;
|
||||
}
|
||||
|
||||
public ValidationLoss getValidationLoss() {
|
||||
return validationLoss;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
if (iteration != null) {
|
||||
builder.field(ITERATION.getPreferredName(), iteration);
|
||||
}
|
||||
builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
|
||||
builder.field(TIMING_STATS.getPreferredName(), timingStats);
|
||||
builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassificationStats that = (ClassificationStats) o;
|
||||
return Objects.equals(timestamp, that.timestamp)
|
||||
&& Objects.equals(iteration, that.iteration)
|
||||
&& Objects.equals(hyperparameters, that.hyperparameters)
|
||||
&& Objects.equals(timingStats, that.timingStats)
|
||||
&& Objects.equals(validationLoss, that.validationLoss);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(timestamp, iteration, hyperparameters, timingStats, validationLoss);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
}
|
@ -0,0 +1,293 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class Hyperparameters implements ToXContentObject {
|
||||
|
||||
public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
||||
public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
||||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
|
||||
public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
|
||||
"max_optimization_rounds_per_hyperparameter");
|
||||
public static final ParseField MAX_TREES = new ParseField("max_trees");
|
||||
public static final ParseField NUM_FOLDS = new ParseField("num_folds");
|
||||
public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
|
||||
public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
|
||||
= new ParseField("regularization_leaf_weight_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
|
||||
public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
|
||||
new ParseField("regularization_tree_size_penalty_multiplier");
|
||||
|
||||
public static ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>("classification_hyperparameters",
|
||||
true,
|
||||
a -> new Hyperparameters(
|
||||
(String) a[0],
|
||||
(Double) a[1],
|
||||
(Double) a[2],
|
||||
(Double) a[3],
|
||||
(Double) a[4],
|
||||
(Integer) a[5],
|
||||
(Integer) a[6],
|
||||
(Integer) a[7],
|
||||
(Integer) a[8],
|
||||
(Integer) a[9],
|
||||
(Double) a[10],
|
||||
(Double) a[11],
|
||||
(Double) a[12],
|
||||
(Double) a[13],
|
||||
(Double) a[14]
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareString(optionalConstructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
|
||||
PARSER.declareDouble(optionalConstructorArg(), ETA);
|
||||
PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
|
||||
PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
||||
PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
|
||||
}
|
||||
|
||||
private final String classAssignmentObjective;
|
||||
private final Double downsampleFactor;
|
||||
private final Double eta;
|
||||
private final Double etaGrowthRatePerTree;
|
||||
private final Double featureBagFraction;
|
||||
private final Integer maxAttemptsToAddTree;
|
||||
private final Integer maxOptimizationRoundsPerHyperparameter;
|
||||
private final Integer maxTrees;
|
||||
private final Integer numFolds;
|
||||
private final Integer numSplitsPerFeature;
|
||||
private final Double regularizationDepthPenaltyMultiplier;
|
||||
private final Double regularizationLeafWeightPenaltyMultiplier;
|
||||
private final Double regularizationSoftTreeDepthLimit;
|
||||
private final Double regularizationSoftTreeDepthTolerance;
|
||||
private final Double regularizationTreeSizePenaltyMultiplier;
|
||||
|
||||
public Hyperparameters(String classAssignmentObjective,
|
||||
Double downsampleFactor,
|
||||
Double eta,
|
||||
Double etaGrowthRatePerTree,
|
||||
Double featureBagFraction,
|
||||
Integer maxAttemptsToAddTree,
|
||||
Integer maxOptimizationRoundsPerHyperparameter,
|
||||
Integer maxTrees,
|
||||
Integer numFolds,
|
||||
Integer numSplitsPerFeature,
|
||||
Double regularizationDepthPenaltyMultiplier,
|
||||
Double regularizationLeafWeightPenaltyMultiplier,
|
||||
Double regularizationSoftTreeDepthLimit,
|
||||
Double regularizationSoftTreeDepthTolerance,
|
||||
Double regularizationTreeSizePenaltyMultiplier) {
|
||||
this.classAssignmentObjective = classAssignmentObjective;
|
||||
this.downsampleFactor = downsampleFactor;
|
||||
this.eta = eta;
|
||||
this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.maxAttemptsToAddTree = maxAttemptsToAddTree;
|
||||
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
||||
this.maxTrees = maxTrees;
|
||||
this.numFolds = numFolds;
|
||||
this.numSplitsPerFeature = numSplitsPerFeature;
|
||||
this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
|
||||
this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
|
||||
this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
|
||||
this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
|
||||
this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
public String getClassAssignmentObjective() {
|
||||
return classAssignmentObjective;
|
||||
}
|
||||
|
||||
public Double getDownsampleFactor() {
|
||||
return downsampleFactor;
|
||||
}
|
||||
|
||||
public Double getEta() {
|
||||
return eta;
|
||||
}
|
||||
|
||||
public Double getEtaGrowthRatePerTree() {
|
||||
return etaGrowthRatePerTree;
|
||||
}
|
||||
|
||||
public Double getFeatureBagFraction() {
|
||||
return featureBagFraction;
|
||||
}
|
||||
|
||||
public Integer getMaxAttemptsToAddTree() {
|
||||
return maxAttemptsToAddTree;
|
||||
}
|
||||
|
||||
public Integer getMaxOptimizationRoundsPerHyperparameter() {
|
||||
return maxOptimizationRoundsPerHyperparameter;
|
||||
}
|
||||
|
||||
public Integer getMaxTrees() {
|
||||
return maxTrees;
|
||||
}
|
||||
|
||||
public Integer getNumFolds() {
|
||||
return numFolds;
|
||||
}
|
||||
|
||||
public Integer getNumSplitsPerFeature() {
|
||||
return numSplitsPerFeature;
|
||||
}
|
||||
|
||||
public Double getRegularizationDepthPenaltyMultiplier() {
|
||||
return regularizationDepthPenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Double getRegularizationLeafWeightPenaltyMultiplier() {
|
||||
return regularizationLeafWeightPenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Double getRegularizationSoftTreeDepthLimit() {
|
||||
return regularizationSoftTreeDepthLimit;
|
||||
}
|
||||
|
||||
public Double getRegularizationSoftTreeDepthTolerance() {
|
||||
return regularizationSoftTreeDepthTolerance;
|
||||
}
|
||||
|
||||
public Double getRegularizationTreeSizePenaltyMultiplier() {
|
||||
return regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (classAssignmentObjective != null) {
|
||||
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
|
||||
}
|
||||
if (downsampleFactor != null) {
|
||||
builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
||||
}
|
||||
if (eta != null) {
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (etaGrowthRatePerTree != null) {
|
||||
builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (maxAttemptsToAddTree != null) {
|
||||
builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
|
||||
}
|
||||
if (maxOptimizationRoundsPerHyperparameter != null) {
|
||||
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
||||
}
|
||||
if (maxTrees != null) {
|
||||
builder.field(MAX_TREES.getPreferredName(), maxTrees);
|
||||
}
|
||||
if (numFolds != null) {
|
||||
builder.field(NUM_FOLDS.getPreferredName(), numFolds);
|
||||
}
|
||||
if (numSplitsPerFeature != null) {
|
||||
builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
|
||||
}
|
||||
if (regularizationDepthPenaltyMultiplier != null) {
|
||||
builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
|
||||
}
|
||||
if (regularizationLeafWeightPenaltyMultiplier != null) {
|
||||
builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
|
||||
}
|
||||
if (regularizationSoftTreeDepthLimit != null) {
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
|
||||
}
|
||||
if (regularizationSoftTreeDepthTolerance != null) {
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
|
||||
}
|
||||
if (regularizationTreeSizePenaltyMultiplier != null) {
|
||||
builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Hyperparameters that = (Hyperparameters) o;
|
||||
return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
||||
&& Objects.equals(downsampleFactor, that.downsampleFactor)
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
|
||||
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
|
||||
&& Objects.equals(maxTrees, that.maxTrees)
|
||||
&& Objects.equals(numFolds, that.numFolds)
|
||||
&& Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
|
||||
&& Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
|
||||
&& Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
|
||||
&& Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
|
||||
&& Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
|
||||
&& Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(
|
||||
classAssignmentObjective,
|
||||
downsampleFactor,
|
||||
eta,
|
||||
etaGrowthRatePerTree,
|
||||
featureBagFraction,
|
||||
maxAttemptsToAddTree,
|
||||
maxOptimizationRoundsPerHyperparameter,
|
||||
maxTrees,
|
||||
numFolds,
|
||||
numSplitsPerFeature,
|
||||
regularizationDepthPenaltyMultiplier,
|
||||
regularizationLeafWeightPenaltyMultiplier,
|
||||
regularizationSoftTreeDepthLimit,
|
||||
regularizationSoftTreeDepthTolerance,
|
||||
regularizationTreeSizePenaltyMultiplier
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TimingStats implements ToXContentObject {
|
||||
|
||||
public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
|
||||
public static final ParseField ITERATION_TIME = new ParseField("iteration_time");
|
||||
|
||||
public static ConstructingObjectParser<TimingStats, Void> PARSER = new ConstructingObjectParser<>("classification_timing_stats", true,
|
||||
a -> new TimingStats(
|
||||
a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]),
|
||||
a[1] == null ? null : TimeValue.timeValueMillis((long) a[1])
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME);
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ITERATION_TIME);
|
||||
}
|
||||
|
||||
private final TimeValue elapsedTime;
|
||||
private final TimeValue iterationTime;
|
||||
|
||||
public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
|
||||
this.elapsedTime = elapsedTime;
|
||||
this.iterationTime = iterationTime;
|
||||
}
|
||||
|
||||
public TimeValue getElapsedTime() {
|
||||
return elapsedTime;
|
||||
}
|
||||
|
||||
public TimeValue getIterationTime() {
|
||||
return iterationTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (elapsedTime != null) {
|
||||
builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
|
||||
}
|
||||
if (iterationTime != null) {
|
||||
builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TimingStats that = (TimingStats) o;
|
||||
return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(elapsedTime, iterationTime);
|
||||
}
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValues;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ValidationLoss implements ToXContentObject {
|
||||
|
||||
public static final ParseField LOSS_TYPE = new ParseField("loss_type");
|
||||
public static final ParseField FOLD_VALUES = new ParseField("fold_values");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static ConstructingObjectParser<ValidationLoss, Void> PARSER = new ConstructingObjectParser<>("classification_validation_loss",
|
||||
true,
|
||||
a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LOSS_TYPE);
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), FoldValues.PARSER, FOLD_VALUES);
|
||||
}
|
||||
|
||||
private final String lossType;
|
||||
private final List<FoldValues> foldValues;
|
||||
|
||||
public ValidationLoss(String lossType, List<FoldValues> values) {
|
||||
this.lossType = lossType;
|
||||
this.foldValues = values;
|
||||
}
|
||||
|
||||
public String getLossType() {
|
||||
return lossType;
|
||||
}
|
||||
|
||||
public List<FoldValues> getFoldValues() {
|
||||
return foldValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (lossType != null) {
|
||||
builder.field(LOSS_TYPE.getPreferredName(), lossType);
|
||||
}
|
||||
if (foldValues != null) {
|
||||
builder.field(FOLD_VALUES.getPreferredName(), foldValues);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ValidationLoss that = (ValidationLoss) o;
|
||||
return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(lossType, foldValues);
|
||||
}
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.common;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class FoldValues implements ToXContentObject {
|
||||
|
||||
public static final ParseField FOLD = new ParseField("fold");
|
||||
public static final ParseField VALUES = new ParseField("values");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static ConstructingObjectParser<FoldValues, Void> PARSER = new ConstructingObjectParser<>("fold_values", true,
|
||||
a -> new FoldValues((int) a[0], (List<Double>) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareInt(ConstructingObjectParser.constructorArg(), FOLD);
|
||||
PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), VALUES);
|
||||
}
|
||||
|
||||
private final int fold;
|
||||
private final double[] values;
|
||||
|
||||
private FoldValues(int fold, List<Double> values) {
|
||||
this(fold, values.stream().mapToDouble(Double::doubleValue).toArray());
|
||||
}
|
||||
|
||||
public FoldValues(int fold, double[] values) {
|
||||
this.fold = fold;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
public int getFold() {
|
||||
return fold;
|
||||
}
|
||||
|
||||
public double[] getValues() {
|
||||
return values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FOLD.getPreferredName(), fold);
|
||||
builder.array(VALUES.getPreferredName(), values);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (o == this) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
FoldValues other = (FoldValues) o;
|
||||
return fold == other.fold && Arrays.equals(values, other.values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(fold, Arrays.hashCode(values));
|
||||
}
|
||||
}
|
@ -16,7 +16,7 @@
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe;
|
||||
package org.elasticsearch.client.ml.dataframe.stats.common;
|
||||
|
||||
import org.elasticsearch.client.common.TimeUtil;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
@ -54,6 +54,14 @@ public class MemoryUsage implements ToXContentObject {
|
||||
this.peakUsageBytes = peakUsageBytes;
|
||||
}
|
||||
|
||||
public Instant getTimestamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
public long getPeakUsageBytes() {
|
||||
return peakUsageBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
@ -0,0 +1,105 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.client.common.TimeUtil;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public class OutlierDetectionStats implements AnalysisStats {
|
||||
|
||||
public static final ParseField NAME = new ParseField("outlier_detection_stats");
|
||||
|
||||
public static final ParseField TIMESTAMP = new ParseField("timestamp");
|
||||
public static final ParseField PARAMETERS = new ParseField("parameters");
|
||||
public static final ParseField TIMING_STATS = new ParseField("timings_stats");
|
||||
|
||||
public static final ConstructingObjectParser<OutlierDetectionStats, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), true,
|
||||
a -> new OutlierDetectionStats((Instant) a[0], (Parameters) a[1], (TimingStats) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
|
||||
TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Parameters.PARSER, PARAMETERS);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS);
|
||||
}
|
||||
|
||||
private final Instant timestamp;
|
||||
private final Parameters parameters;
|
||||
private final TimingStats timingStats;
|
||||
|
||||
public OutlierDetectionStats(Instant timestamp, Parameters parameters, TimingStats timingStats) {
|
||||
this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
|
||||
this.parameters = Objects.requireNonNull(parameters);
|
||||
this.timingStats = Objects.requireNonNull(timingStats);
|
||||
}
|
||||
|
||||
public Instant getTimestamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
public Parameters getParameters() {
|
||||
return parameters;
|
||||
}
|
||||
|
||||
public TimingStats getTimingStats() {
|
||||
return timingStats;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
builder.field(PARAMETERS.getPreferredName(), parameters);
|
||||
builder.field(TIMING_STATS.getPreferredName(), timingStats);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
OutlierDetectionStats that = (OutlierDetectionStats) o;
|
||||
return Objects.equals(timestamp, that.timestamp)
|
||||
&& Objects.equals(parameters, that.parameters)
|
||||
&& Objects.equals(timingStats, that.timingStats);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(timestamp, parameters, timingStats);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
}
|
@ -0,0 +1,146 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class Parameters implements ToXContentObject {
|
||||
|
||||
public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors");
|
||||
public static final ParseField METHOD = new ParseField("method");
|
||||
public static final ParseField FEATURE_INFLUENCE_THRESHOLD = new ParseField("feature_influence_threshold");
|
||||
public static final ParseField COMPUTE_FEATURE_INFLUENCE = new ParseField("compute_feature_influence");
|
||||
public static final ParseField OUTLIER_FRACTION = new ParseField("outlier_fraction");
|
||||
public static final ParseField STANDARDIZATION_ENABLED = new ParseField("standardization_enabled");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static ConstructingObjectParser<Parameters, Void> PARSER = new ConstructingObjectParser<>("outlier_detection_parameters",
|
||||
true,
|
||||
a -> new Parameters(
|
||||
(Integer) a[0],
|
||||
(String) a[1],
|
||||
(Boolean) a[2],
|
||||
(Double) a[3],
|
||||
(Double) a[4],
|
||||
(Boolean) a[5]
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareInt(optionalConstructorArg(), N_NEIGHBORS);
|
||||
PARSER.declareString(optionalConstructorArg(), METHOD);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), COMPUTE_FEATURE_INFLUENCE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), FEATURE_INFLUENCE_THRESHOLD);
|
||||
PARSER.declareDouble(optionalConstructorArg(), OUTLIER_FRACTION);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), STANDARDIZATION_ENABLED);
|
||||
}
|
||||
|
||||
private final Integer nNeighbors;
|
||||
private final String method;
|
||||
private final Boolean computeFeatureInfluence;
|
||||
private final Double featureInfluenceThreshold;
|
||||
private final Double outlierFraction;
|
||||
private final Boolean standardizationEnabled;
|
||||
|
||||
public Parameters(Integer nNeighbors, String method, Boolean computeFeatureInfluence, Double featureInfluenceThreshold,
|
||||
Double outlierFraction, Boolean standardizationEnabled) {
|
||||
this.nNeighbors = nNeighbors;
|
||||
this.method = method;
|
||||
this.computeFeatureInfluence = computeFeatureInfluence;
|
||||
this.featureInfluenceThreshold = featureInfluenceThreshold;
|
||||
this.outlierFraction = outlierFraction;
|
||||
this.standardizationEnabled = standardizationEnabled;
|
||||
}
|
||||
|
||||
public Integer getnNeighbors() {
|
||||
return nNeighbors;
|
||||
}
|
||||
|
||||
public String getMethod() {
|
||||
return method;
|
||||
}
|
||||
|
||||
public Boolean getComputeFeatureInfluence() {
|
||||
return computeFeatureInfluence;
|
||||
}
|
||||
|
||||
public Double getFeatureInfluenceThreshold() {
|
||||
return featureInfluenceThreshold;
|
||||
}
|
||||
|
||||
public Double getOutlierFraction() {
|
||||
return outlierFraction;
|
||||
}
|
||||
|
||||
public Boolean getStandardizationEnabled() {
|
||||
return standardizationEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (nNeighbors != null) {
|
||||
builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors);
|
||||
}
|
||||
if (method != null) {
|
||||
builder.field(METHOD.getPreferredName(), method);
|
||||
}
|
||||
if (computeFeatureInfluence != null) {
|
||||
builder.field(COMPUTE_FEATURE_INFLUENCE.getPreferredName(), computeFeatureInfluence);
|
||||
}
|
||||
if (featureInfluenceThreshold != null) {
|
||||
builder.field(FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), featureInfluenceThreshold);
|
||||
}
|
||||
if (outlierFraction != null) {
|
||||
builder.field(OUTLIER_FRACTION.getPreferredName(), outlierFraction);
|
||||
}
|
||||
if (standardizationEnabled != null) {
|
||||
builder.field(STANDARDIZATION_ENABLED.getPreferredName(), standardizationEnabled);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Parameters that = (Parameters) o;
|
||||
return Objects.equals(nNeighbors, that.nNeighbors)
|
||||
&& Objects.equals(method, that.method)
|
||||
&& Objects.equals(computeFeatureInfluence, that.computeFeatureInfluence)
|
||||
&& Objects.equals(featureInfluenceThreshold, that.featureInfluenceThreshold)
|
||||
&& Objects.equals(outlierFraction, that.outlierFraction)
|
||||
&& Objects.equals(standardizationEnabled, that.standardizationEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(nNeighbors, method, computeFeatureInfluence, featureInfluenceThreshold, outlierFraction,
|
||||
standardizationEnabled);
|
||||
}
|
||||
}
|
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TimingStats implements ToXContentObject {
|
||||
|
||||
public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
|
||||
|
||||
public static ConstructingObjectParser<TimingStats, Void> PARSER = new ConstructingObjectParser<>("outlier_detection_timing_stats",
|
||||
true,
|
||||
a -> new TimingStats(a[0] == null ? null : TimeValue.timeValueMillis((long) a[0])));
|
||||
|
||||
static {
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME);
|
||||
}
|
||||
|
||||
private final TimeValue elapsedTime;
|
||||
|
||||
public TimingStats(TimeValue elapsedTime) {
|
||||
this.elapsedTime = elapsedTime;
|
||||
}
|
||||
|
||||
public TimeValue getElapsedTime() {
|
||||
return elapsedTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (elapsedTime != null) {
|
||||
builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TimingStats that = (TimingStats) o;
|
||||
return Objects.equals(elapsedTime, that.elapsedTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(elapsedTime);
|
||||
}
|
||||
}
|
@ -0,0 +1,278 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class Hyperparameters implements ToXContentObject {
|
||||
|
||||
public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
||||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
|
||||
public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
|
||||
"max_optimization_rounds_per_hyperparameter");
|
||||
public static final ParseField MAX_TREES = new ParseField("max_trees");
|
||||
public static final ParseField NUM_FOLDS = new ParseField("num_folds");
|
||||
public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
|
||||
public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
|
||||
= new ParseField("regularization_leaf_weight_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
|
||||
public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
|
||||
new ParseField("regularization_tree_size_penalty_multiplier");
|
||||
|
||||
public static ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>("regression_hyperparameters",
|
||||
true,
|
||||
a -> new Hyperparameters(
|
||||
(Double) a[0],
|
||||
(Double) a[1],
|
||||
(Double) a[2],
|
||||
(Double) a[3],
|
||||
(Integer) a[4],
|
||||
(Integer) a[5],
|
||||
(Integer) a[6],
|
||||
(Integer) a[7],
|
||||
(Integer) a[8],
|
||||
(Double) a[9],
|
||||
(Double) a[10],
|
||||
(Double) a[11],
|
||||
(Double) a[12],
|
||||
(Double) a[13]
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
|
||||
PARSER.declareDouble(optionalConstructorArg(), ETA);
|
||||
PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
|
||||
PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
||||
PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
|
||||
PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
|
||||
}
|
||||
|
||||
private final Double downsampleFactor;
|
||||
private final Double eta;
|
||||
private final Double etaGrowthRatePerTree;
|
||||
private final Double featureBagFraction;
|
||||
private final Integer maxAttemptsToAddTree;
|
||||
private final Integer maxOptimizationRoundsPerHyperparameter;
|
||||
private final Integer maxTrees;
|
||||
private final Integer numFolds;
|
||||
private final Integer numSplitsPerFeature;
|
||||
private final Double regularizationDepthPenaltyMultiplier;
|
||||
private final Double regularizationLeafWeightPenaltyMultiplier;
|
||||
private final Double regularizationSoftTreeDepthLimit;
|
||||
private final Double regularizationSoftTreeDepthTolerance;
|
||||
private final Double regularizationTreeSizePenaltyMultiplier;
|
||||
|
||||
public Hyperparameters(Double downsampleFactor,
|
||||
Double eta,
|
||||
Double etaGrowthRatePerTree,
|
||||
Double featureBagFraction,
|
||||
Integer maxAttemptsToAddTree,
|
||||
Integer maxOptimizationRoundsPerHyperparameter,
|
||||
Integer maxTrees,
|
||||
Integer numFolds,
|
||||
Integer numSplitsPerFeature,
|
||||
Double regularizationDepthPenaltyMultiplier,
|
||||
Double regularizationLeafWeightPenaltyMultiplier,
|
||||
Double regularizationSoftTreeDepthLimit,
|
||||
Double regularizationSoftTreeDepthTolerance,
|
||||
Double regularizationTreeSizePenaltyMultiplier) {
|
||||
this.downsampleFactor = downsampleFactor;
|
||||
this.eta = eta;
|
||||
this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.maxAttemptsToAddTree = maxAttemptsToAddTree;
|
||||
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
||||
this.maxTrees = maxTrees;
|
||||
this.numFolds = numFolds;
|
||||
this.numSplitsPerFeature = numSplitsPerFeature;
|
||||
this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
|
||||
this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
|
||||
this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
|
||||
this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
|
||||
this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Double getDownsampleFactor() {
|
||||
return downsampleFactor;
|
||||
}
|
||||
|
||||
public Double getEta() {
|
||||
return eta;
|
||||
}
|
||||
|
||||
public Double getEtaGrowthRatePerTree() {
|
||||
return etaGrowthRatePerTree;
|
||||
}
|
||||
|
||||
public Double getFeatureBagFraction() {
|
||||
return featureBagFraction;
|
||||
}
|
||||
|
||||
public Integer getMaxAttemptsToAddTree() {
|
||||
return maxAttemptsToAddTree;
|
||||
}
|
||||
|
||||
public Integer getMaxOptimizationRoundsPerHyperparameter() {
|
||||
return maxOptimizationRoundsPerHyperparameter;
|
||||
}
|
||||
|
||||
public Integer getMaxTrees() {
|
||||
return maxTrees;
|
||||
}
|
||||
|
||||
public Integer getNumFolds() {
|
||||
return numFolds;
|
||||
}
|
||||
|
||||
public Integer getNumSplitsPerFeature() {
|
||||
return numSplitsPerFeature;
|
||||
}
|
||||
|
||||
public Double getRegularizationDepthPenaltyMultiplier() {
|
||||
return regularizationDepthPenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Double getRegularizationLeafWeightPenaltyMultiplier() {
|
||||
return regularizationLeafWeightPenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Double getRegularizationSoftTreeDepthLimit() {
|
||||
return regularizationSoftTreeDepthLimit;
|
||||
}
|
||||
|
||||
public Double getRegularizationSoftTreeDepthTolerance() {
|
||||
return regularizationSoftTreeDepthTolerance;
|
||||
}
|
||||
|
||||
public Double getRegularizationTreeSizePenaltyMultiplier() {
|
||||
return regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (downsampleFactor != null) {
|
||||
builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
||||
}
|
||||
if (eta != null) {
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
}
|
||||
if (etaGrowthRatePerTree != null) {
|
||||
builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
|
||||
}
|
||||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (maxAttemptsToAddTree != null) {
|
||||
builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
|
||||
}
|
||||
if (maxOptimizationRoundsPerHyperparameter != null) {
|
||||
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
||||
}
|
||||
if (maxTrees != null) {
|
||||
builder.field(MAX_TREES.getPreferredName(), maxTrees);
|
||||
}
|
||||
if (numFolds != null) {
|
||||
builder.field(NUM_FOLDS.getPreferredName(), numFolds);
|
||||
}
|
||||
if (numSplitsPerFeature != null) {
|
||||
builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
|
||||
}
|
||||
if (regularizationDepthPenaltyMultiplier != null) {
|
||||
builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
|
||||
}
|
||||
if (regularizationLeafWeightPenaltyMultiplier != null) {
|
||||
builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
|
||||
}
|
||||
if (regularizationSoftTreeDepthLimit != null) {
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
|
||||
}
|
||||
if (regularizationSoftTreeDepthTolerance != null) {
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
|
||||
}
|
||||
if (regularizationTreeSizePenaltyMultiplier != null) {
|
||||
builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Hyperparameters that = (Hyperparameters) o;
|
||||
return Objects.equals(downsampleFactor, that.downsampleFactor)
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
|
||||
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
|
||||
&& Objects.equals(maxTrees, that.maxTrees)
|
||||
&& Objects.equals(numFolds, that.numFolds)
|
||||
&& Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
|
||||
&& Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
|
||||
&& Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
|
||||
&& Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
|
||||
&& Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
|
||||
&& Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(
|
||||
downsampleFactor,
|
||||
eta,
|
||||
etaGrowthRatePerTree,
|
||||
featureBagFraction,
|
||||
maxAttemptsToAddTree,
|
||||
maxOptimizationRoundsPerHyperparameter,
|
||||
maxTrees,
|
||||
numFolds,
|
||||
numSplitsPerFeature,
|
||||
regularizationDepthPenaltyMultiplier,
|
||||
regularizationLeafWeightPenaltyMultiplier,
|
||||
regularizationSoftTreeDepthLimit,
|
||||
regularizationSoftTreeDepthTolerance,
|
||||
regularizationTreeSizePenaltyMultiplier
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,135 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.client.common.TimeUtil;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public class RegressionStats implements AnalysisStats {
|
||||
|
||||
public static final ParseField NAME = new ParseField("regression_stats");
|
||||
|
||||
public static final ParseField TIMESTAMP = new ParseField("timestamp");
|
||||
public static final ParseField ITERATION = new ParseField("iteration");
|
||||
public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
|
||||
public static final ParseField TIMING_STATS = new ParseField("timing_stats");
|
||||
public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");
|
||||
|
||||
public static final ConstructingObjectParser<RegressionStats, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
|
||||
true,
|
||||
a -> new RegressionStats(
|
||||
(Instant) a[0],
|
||||
(Integer) a[1],
|
||||
(Hyperparameters) a[2],
|
||||
(TimingStats) a[3],
|
||||
(ValidationLoss) a[4]
|
||||
)
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
|
||||
TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), ITERATION);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Hyperparameters.PARSER, HYPERPARAMETERS);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), TimingStats.PARSER, TIMING_STATS);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), ValidationLoss.PARSER, VALIDATION_LOSS);
|
||||
}
|
||||
|
||||
private final Instant timestamp;
|
||||
private final Integer iteration;
|
||||
private final Hyperparameters hyperparameters;
|
||||
private final TimingStats timingStats;
|
||||
private final ValidationLoss validationLoss;
|
||||
|
||||
public RegressionStats(Instant timestamp, Integer iteration, Hyperparameters hyperparameters, TimingStats timingStats,
|
||||
ValidationLoss validationLoss) {
|
||||
this.timestamp = Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
|
||||
this.iteration = iteration;
|
||||
this.hyperparameters = Objects.requireNonNull(hyperparameters);
|
||||
this.timingStats = Objects.requireNonNull(timingStats);
|
||||
this.validationLoss = Objects.requireNonNull(validationLoss);
|
||||
}
|
||||
|
||||
public Instant getTimestamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
public Integer getIteration() {
|
||||
return iteration;
|
||||
}
|
||||
|
||||
public Hyperparameters getHyperparameters() {
|
||||
return hyperparameters;
|
||||
}
|
||||
|
||||
public TimingStats getTimingStats() {
|
||||
return timingStats;
|
||||
}
|
||||
|
||||
public ValidationLoss getValidationLoss() {
|
||||
return validationLoss;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
if (iteration != null) {
|
||||
builder.field(ITERATION.getPreferredName(), iteration);
|
||||
}
|
||||
builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
|
||||
builder.field(TIMING_STATS.getPreferredName(), timingStats);
|
||||
builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
RegressionStats that = (RegressionStats) o;
|
||||
return Objects.equals(timestamp, that.timestamp)
|
||||
&& Objects.equals(iteration, that.iteration)
|
||||
&& Objects.equals(hyperparameters, that.hyperparameters)
|
||||
&& Objects.equals(timingStats, that.timingStats)
|
||||
&& Objects.equals(validationLoss, that.validationLoss);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(timestamp, iteration, hyperparameters, timingStats, validationLoss);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TimingStats implements ToXContentObject {
|
||||
|
||||
public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
|
||||
public static final ParseField ITERATION_TIME = new ParseField("iteration_time");
|
||||
|
||||
public static ConstructingObjectParser<TimingStats, Void> PARSER = new ConstructingObjectParser<>("regression_timing_stats", true,
|
||||
a -> new TimingStats(
|
||||
a[0] == null ? null : TimeValue.timeValueMillis((long) a[0]),
|
||||
a[1] == null ? null : TimeValue.timeValueMillis((long) a[1])
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ELAPSED_TIME);
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ITERATION_TIME);
|
||||
}
|
||||
|
||||
private final TimeValue elapsedTime;
|
||||
private final TimeValue iterationTime;
|
||||
|
||||
public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
|
||||
this.elapsedTime = elapsedTime;
|
||||
this.iterationTime = iterationTime;
|
||||
}
|
||||
|
||||
public TimeValue getElapsedTime() {
|
||||
return elapsedTime;
|
||||
}
|
||||
|
||||
public TimeValue getIterationTime() {
|
||||
return iterationTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (elapsedTime != null) {
|
||||
builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
|
||||
}
|
||||
if (iterationTime != null) {
|
||||
builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TimingStats that = (TimingStats) o;
|
||||
return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(elapsedTime, iterationTime);
|
||||
}
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValues;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ValidationLoss implements ToXContentObject {
|
||||
|
||||
public static final ParseField LOSS_TYPE = new ParseField("loss_type");
|
||||
public static final ParseField FOLD_VALUES = new ParseField("fold_values");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static ConstructingObjectParser<ValidationLoss, Void> PARSER = new ConstructingObjectParser<>("regression_validation_loss",
|
||||
true,
|
||||
a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LOSS_TYPE);
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), FoldValues.PARSER, FOLD_VALUES);
|
||||
}
|
||||
|
||||
private final String lossType;
|
||||
private final List<FoldValues> foldValues;
|
||||
|
||||
public ValidationLoss(String lossType, List<FoldValues> values) {
|
||||
this.lossType = lossType;
|
||||
this.foldValues = values;
|
||||
}
|
||||
|
||||
public String getLossType() {
|
||||
return lossType;
|
||||
}
|
||||
|
||||
public List<FoldValues> getFoldValues() {
|
||||
return foldValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (lossType != null) {
|
||||
builder.field(LOSS_TYPE.getPreferredName(), lossType);
|
||||
}
|
||||
if (foldValues != null) {
|
||||
builder.field(FOLD_VALUES.getPreferredName(), foldValues);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ValidationLoss that = (ValidationLoss) o;
|
||||
return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(lossType, foldValues);
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider
|
||||
org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider
|
||||
org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider
|
||||
org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider
|
||||
org.elasticsearch.client.transform.TransformNamedXContentProvider
|
||||
|
@ -91,6 +91,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
||||
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
@ -1067,6 +1068,7 @@ public class MLRequestConvertersTests extends ESTestCase {
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new AnalysisStatsNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,6 @@
|
||||
package org.elasticsearch.client;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonParseException;
|
||||
|
||||
import org.apache.http.HttpEntity;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.HttpResponse;
|
||||
@ -69,6 +68,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
|
||||
@ -697,7 +699,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(59, namedXContents.size());
|
||||
assertEquals(62, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
@ -707,7 +709,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
categories.put(namedXContent.categoryClass, counter + 1);
|
||||
}
|
||||
}
|
||||
assertEquals("Had: " + categories, 12, categories.size());
|
||||
assertEquals("Had: " + categories, 13, categories.size());
|
||||
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
|
||||
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
|
||||
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
|
||||
@ -737,6 +739,9 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(OutlierDetectionStats.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(RegressionStats.NAME.getPreferredName()));
|
||||
assertTrue(names.contains(ClassificationStats.NAME.getPreferredName()));
|
||||
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
|
||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
|
@ -20,6 +20,13 @@
|
||||
package org.elasticsearch.client.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.client.ml.NodeAttributesTests;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
|
||||
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
@ -31,23 +38,38 @@ import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
|
||||
|
||||
public class DataFrameAnalyticsStatsTests extends ESTestCase {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new AnalysisStatsNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
public void testFromXContent() throws IOException {
|
||||
xContentTester(this::createParser,
|
||||
DataFrameAnalyticsStatsTests::randomDataFrameAnalyticsStats,
|
||||
DataFrameAnalyticsStatsTests::toXContent,
|
||||
DataFrameAnalyticsStats::fromXContent)
|
||||
.supportsUnknownFields(true)
|
||||
.randomFieldsExcludeFilter(field -> field.startsWith("node.attributes"))
|
||||
.randomFieldsExcludeFilter(field -> field.startsWith("node.attributes") || field.startsWith("analysis_stats"))
|
||||
.test();
|
||||
}
|
||||
|
||||
public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() {
|
||||
AnalysisStats analysisStats = randomBoolean() ? null :
|
||||
randomFrom(
|
||||
ClassificationStatsTests.createRandom(),
|
||||
OutlierDetectionStatsTests.createRandom(),
|
||||
RegressionStatsTests.createRandom()
|
||||
);
|
||||
|
||||
return new DataFrameAnalyticsStats(
|
||||
randomAlphaOfLengthBetween(1, 10),
|
||||
randomFrom(DataFrameAnalyticsState.values()),
|
||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : createRandomProgress(),
|
||||
randomBoolean() ? null : MemoryUsageTests.createRandom(),
|
||||
analysisStats,
|
||||
randomBoolean() ? null : NodeAttributesTests.createRandom(),
|
||||
randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20));
|
||||
}
|
||||
@ -74,6 +96,11 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase {
|
||||
if (stats.getMemoryUsage() != null) {
|
||||
builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage());
|
||||
}
|
||||
if (stats.getAnalysisStats() != null) {
|
||||
builder.startObject("analysis_stats");
|
||||
builder.field(stats.getAnalysisStats().getName(), stats.getAnalysisStats());
|
||||
builder.endObject();
|
||||
}
|
||||
if (stats.getNode() != null) {
|
||||
builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode());
|
||||
}
|
||||
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
|
||||
public class ClassificationStatsTests extends AbstractXContentTestCase<ClassificationStats> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return ClassificationStats.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static ClassificationStats createRandom() {
|
||||
return new ClassificationStats(
|
||||
Instant.now(),
|
||||
randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE),
|
||||
HyperparametersTests.createRandom(),
|
||||
TimingStatsTests.createRandom(),
|
||||
ValidationLossTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class HyperparametersTests extends AbstractXContentTestCase<Hyperparameters> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
|
||||
return Hyperparameters.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Hyperparameters createRandom() {
|
||||
return new Hyperparameters(
|
||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TimingStatsTests extends AbstractXContentTestCase<TimingStats> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return TimingStats.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TimingStats createRandom() {
|
||||
return new TimingStats(
|
||||
randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()),
|
||||
randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong())
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValuesTests;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ValidationLossTests extends AbstractXContentTestCase<ValidationLoss> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
|
||||
return ValidationLoss.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static ValidationLoss createRandom() {
|
||||
return new ValidationLoss(
|
||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomList(5, () -> FoldValuesTests.createRandom())
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.common;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class FoldValuesTests extends AbstractXContentTestCase<FoldValues> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FoldValues doParseInstance(XContentParser parser) throws IOException {
|
||||
return FoldValues.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FoldValues createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static FoldValues createRandom() {
|
||||
int valuesSize = randomIntBetween(0, 10);
|
||||
double[] values = new double[valuesSize];
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
values[i] = randomDouble();
|
||||
}
|
||||
return new FoldValues(randomIntBetween(0, Integer.MAX_VALUE), values);
|
||||
}
|
||||
}
|
@ -16,7 +16,7 @@
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe;
|
||||
package org.elasticsearch.client.ml.dataframe.stats.common;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
|
||||
public class OutlierDetectionStatsTests extends AbstractXContentTestCase<OutlierDetectionStats> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OutlierDetectionStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return OutlierDetectionStats.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OutlierDetectionStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static OutlierDetectionStats createRandom() {
|
||||
return new OutlierDetectionStats(
|
||||
Instant.now(),
|
||||
ParametersTests.createRandom(),
|
||||
TimingStatsTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ParametersTests extends AbstractXContentTestCase<Parameters> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Parameters doParseInstance(XContentParser parser) throws IOException {
|
||||
return Parameters.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Parameters createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Parameters createRandom() {
|
||||
return new Parameters(
|
||||
randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomAlphaOfLength(5),
|
||||
randomBoolean() ? null : randomBoolean(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomBoolean()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TimingStatsTests extends AbstractXContentTestCase<TimingStats> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected TimingStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return TimingStats.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TimingStats createRandom() {
|
||||
return new TimingStats(randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()));
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class HyperparametersTests extends AbstractXContentTestCase<Hyperparameters> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
|
||||
return Hyperparameters.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected Hyperparameters createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Hyperparameters createRandom() {
|
||||
return new Hyperparameters(
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble(),
|
||||
randomBoolean() ? null : randomDouble()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
|
||||
public class RegressionStatsTests extends AbstractXContentTestCase<RegressionStats> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return RegressionStats.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected RegressionStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static RegressionStats createRandom() {
|
||||
return new RegressionStats(
|
||||
Instant.now(),
|
||||
randomBoolean() ? null : randomIntBetween(1, Integer.MAX_VALUE),
|
||||
HyperparametersTests.createRandom(),
|
||||
TimingStatsTests.createRandom(),
|
||||
ValidationLossTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TimingStatsTests extends AbstractXContentTestCase<TimingStats> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return TimingStats.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TimingStats createRandom() {
|
||||
return new TimingStats(
|
||||
randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong()),
|
||||
randomBoolean() ? null : TimeValue.timeValueMillis(randomNonNegativeLong())
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.stats.common.FoldValuesTests;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ValidationLossTests extends AbstractXContentTestCase<ValidationLoss> {
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
|
||||
return ValidationLoss.PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static ValidationLoss createRandom() {
|
||||
return new ValidationLoss(
|
||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomList(5, () -> FoldValuesTests.createRandom())
|
||||
);
|
||||
}
|
||||
}
|
@ -39,18 +39,15 @@ import org.elasticsearch.transport.Transport;
|
||||
import org.elasticsearch.xpack.core.action.XPackInfoAction;
|
||||
import org.elasticsearch.xpack.core.action.XPackUsageAction;
|
||||
import org.elasticsearch.xpack.core.analytics.AnalyticsFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.search.action.DeleteAsyncSearchAction;
|
||||
import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction;
|
||||
import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
|
||||
import org.elasticsearch.xpack.core.ccr.AutoFollowMetadata;
|
||||
import org.elasticsearch.xpack.core.ccr.CCRFeatureSet;
|
||||
import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction;
|
||||
import org.elasticsearch.xpack.core.enrich.EnrichFeatureSet;
|
||||
import org.elasticsearch.xpack.core.eql.EqlFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.enrich.action.DeleteEnrichPolicyAction;
|
||||
import org.elasticsearch.xpack.core.enrich.action.ExecuteEnrichPolicyAction;
|
||||
import org.elasticsearch.xpack.core.enrich.action.GetEnrichPolicyAction;
|
||||
import org.elasticsearch.xpack.core.enrich.action.PutEnrichPolicyAction;
|
||||
import org.elasticsearch.xpack.core.eql.EqlFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.flattened.FlattenedFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.frozen.FrozenIndicesFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.frozen.action.FreezeIndexAction;
|
||||
@ -152,6 +149,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
@ -184,6 +185,9 @@ import org.elasticsearch.xpack.core.rollup.action.StartRollupJobAction;
|
||||
import org.elasticsearch.xpack.core.rollup.action.StopRollupJobAction;
|
||||
import org.elasticsearch.xpack.core.rollup.job.RollupJob;
|
||||
import org.elasticsearch.xpack.core.rollup.job.RollupJobStatus;
|
||||
import org.elasticsearch.xpack.core.search.action.DeleteAsyncSearchAction;
|
||||
import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction;
|
||||
import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
|
||||
import org.elasticsearch.xpack.core.security.SecurityFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.security.SecurityField;
|
||||
import org.elasticsearch.xpack.core.security.SecuritySettings;
|
||||
@ -505,6 +509,9 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
|
||||
new NamedWriteableRegistry.Entry(AnalysisStats.class, OutlierDetectionStats.TYPE_VALUE, OutlierDetectionStats::new),
|
||||
new NamedWriteableRegistry.Entry(AnalysisStats.class, RegressionStats.TYPE_VALUE, RegressionStats::new),
|
||||
new NamedWriteableRegistry.Entry(AnalysisStats.class, ClassificationStats.TYPE_VALUE, ClassificationStats::new),
|
||||
// ML - Inference preprocessing
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
|
||||
|
@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||
@ -167,18 +168,23 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
|
||||
@Nullable
|
||||
private final MemoryUsage memoryUsage;
|
||||
|
||||
@Nullable
|
||||
private final AnalysisStats analysisStats;
|
||||
|
||||
@Nullable
|
||||
private final DiscoveryNode node;
|
||||
@Nullable
|
||||
private final String assignmentExplanation;
|
||||
|
||||
public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List<PhaseProgress> progress,
|
||||
@Nullable MemoryUsage memoryUsage, @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) {
|
||||
@Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node,
|
||||
@Nullable String assignmentExplanation) {
|
||||
this.id = Objects.requireNonNull(id);
|
||||
this.state = Objects.requireNonNull(state);
|
||||
this.failureReason = failureReason;
|
||||
this.progress = Objects.requireNonNull(progress);
|
||||
this.memoryUsage = memoryUsage;
|
||||
this.analysisStats = analysisStats;
|
||||
this.node = node;
|
||||
this.assignmentExplanation = assignmentExplanation;
|
||||
}
|
||||
@ -197,6 +203,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
|
||||
} else {
|
||||
memoryUsage = null;
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
analysisStats = in.readOptionalNamedWriteable(AnalysisStats.class);
|
||||
} else {
|
||||
analysisStats = null;
|
||||
}
|
||||
node = in.readOptionalWriteable(DiscoveryNode::new);
|
||||
assignmentExplanation = in.readOptionalString();
|
||||
}
|
||||
@ -285,6 +296,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
|
||||
if (memoryUsage != null) {
|
||||
builder.field("memory_usage", memoryUsage);
|
||||
}
|
||||
if (analysisStats != null) {
|
||||
builder.startObject("analysis_stats");
|
||||
builder.field(analysisStats.getWriteableName(), analysisStats);
|
||||
builder.endObject();
|
||||
}
|
||||
if (node != null) {
|
||||
builder.startObject("node");
|
||||
builder.field("id", node.getId());
|
||||
@ -318,6 +334,9 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
|
||||
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
out.writeOptionalWriteable(memoryUsage);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
out.writeOptionalNamedWriteable(analysisStats);
|
||||
}
|
||||
out.writeOptionalWriteable(node);
|
||||
out.writeOptionalString(assignmentExplanation);
|
||||
}
|
||||
@ -350,7 +369,7 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(id, state, failureReason, progress, memoryUsage, node, assignmentExplanation);
|
||||
return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -367,6 +386,7 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
|
||||
&& Objects.equals(this.failureReason, other.failureReason)
|
||||
&& Objects.equals(this.progress, other.progress)
|
||||
&& Objects.equals(this.memoryUsage, other.memoryUsage)
|
||||
&& Objects.equals(this.analysisStats, other.analysisStats)
|
||||
&& Objects.equals(this.node, other.node)
|
||||
&& Objects.equals(this.assignmentExplanation, other.assignmentExplanation);
|
||||
}
|
||||
|
@ -0,0 +1,15 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
||||
/**
|
||||
* Statistics for the data frame analysis
|
||||
*/
|
||||
public interface AnalysisStats extends ToXContentObject, NamedWriteable {
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class AnalysisStatsNamedWriteablesProvider {
|
||||
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
return Arrays.asList(
|
||||
new NamedWriteableRegistry.Entry(AnalysisStats.class, ClassificationStats.TYPE_VALUE, ClassificationStats::new),
|
||||
new NamedWriteableRegistry.Entry(AnalysisStats.class, OutlierDetectionStats.TYPE_VALUE, OutlierDetectionStats::new),
|
||||
new NamedWriteableRegistry.Entry(AnalysisStats.class, RegressionStats.TYPE_VALUE, RegressionStats::new)
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
|
||||
/**
|
||||
* A collection of parse fields commonly used by stats objects
|
||||
*/
|
||||
public final class Fields {
|
||||
|
||||
public static final ParseField TYPE = new ParseField("type");
|
||||
public static final ParseField JOB_ID = new ParseField("job_id");
|
||||
public static final ParseField TIMESTAMP = new ParseField("timestamp");
|
||||
|
||||
private Fields() {}
|
||||
}
|
@ -26,9 +26,6 @@ public class MemoryUsage implements Writeable, ToXContentObject {
|
||||
|
||||
public static final String TYPE_VALUE = "analytics_memory_usage";
|
||||
|
||||
public static final ParseField TYPE = new ParseField("type");
|
||||
public static final ParseField JOB_ID = new ParseField("job_id");
|
||||
public static final ParseField TIMESTAMP = new ParseField("timestamp");
|
||||
public static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes");
|
||||
|
||||
public static final ConstructingObjectParser<MemoryUsage, Void> STRICT_PARSER = createParser(false);
|
||||
@ -38,11 +35,11 @@ public class MemoryUsage implements Writeable, ToXContentObject {
|
||||
ConstructingObjectParser<MemoryUsage, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE,
|
||||
ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2]));
|
||||
|
||||
parser.declareString((bucket, s) -> {}, TYPE);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), JOB_ID);
|
||||
parser.declareString((bucket, s) -> {}, Fields.TYPE);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
|
||||
parser.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtils.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
|
||||
TIMESTAMP,
|
||||
p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
|
||||
Fields.TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
parser.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES);
|
||||
return parser;
|
||||
@ -56,7 +53,7 @@ public class MemoryUsage implements Writeable, ToXContentObject {
|
||||
this.jobId = Objects.requireNonNull(jobId);
|
||||
// We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
|
||||
// internal representation matches toXContent
|
||||
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, TIMESTAMP).toEpochMilli());
|
||||
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
|
||||
this.peakUsageBytes = peakUsageBytes;
|
||||
}
|
||||
|
||||
@ -77,10 +74,10 @@ public class MemoryUsage implements Writeable, ToXContentObject {
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
builder.field(TYPE.getPreferredName(), TYPE_VALUE);
|
||||
builder.field(JOB_ID.getPreferredName(), jobId);
|
||||
builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
|
||||
builder.field(Fields.JOB_ID.getPreferredName(), jobId);
|
||||
}
|
||||
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -0,0 +1,148 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.common.time.TimeUtils;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ClassificationStats implements AnalysisStats {
|
||||
|
||||
public static final String TYPE_VALUE = "classification_stats";
|
||||
|
||||
public static final ParseField ITERATION = new ParseField("iteration");
|
||||
public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
|
||||
public static final ParseField TIMING_STATS = new ParseField("timing_stats");
|
||||
public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");
|
||||
|
||||
public static final ConstructingObjectParser<ClassificationStats, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<ClassificationStats, Void> LENIENT_PARSER = createParser(true);
|
||||
|
||||
private static ConstructingObjectParser<ClassificationStats, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<ClassificationStats, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
|
||||
a -> new ClassificationStats(
|
||||
(String) a[0],
|
||||
(Instant) a[1],
|
||||
(int) a[2],
|
||||
(Hyperparameters) a[3],
|
||||
(TimingStats) a[4],
|
||||
(ValidationLoss) a[5]
|
||||
)
|
||||
);
|
||||
|
||||
parser.declareString((bucket, s) -> {}, Fields.TYPE);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
|
||||
parser.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
|
||||
Fields.TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
parser.declareInt(ConstructingObjectParser.constructorArg(), ITERATION);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> Hyperparameters.fromXContent(p, ignoreUnknownFields), HYPERPARAMETERS);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> ValidationLoss.fromXContent(p, ignoreUnknownFields), VALIDATION_LOSS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final String jobId;
|
||||
private final Instant timestamp;
|
||||
private final int iteration;
|
||||
private final Hyperparameters hyperparameters;
|
||||
private final TimingStats timingStats;
|
||||
private final ValidationLoss validationLoss;
|
||||
|
||||
public ClassificationStats(String jobId, Instant timestamp, int iteration, Hyperparameters hyperparameters, TimingStats timingStats,
|
||||
ValidationLoss validationLoss) {
|
||||
this.jobId = Objects.requireNonNull(jobId);
|
||||
// We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
|
||||
// internal representation matches toXContent
|
||||
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
|
||||
this.iteration = iteration;
|
||||
this.hyperparameters = Objects.requireNonNull(hyperparameters);
|
||||
this.timingStats = Objects.requireNonNull(timingStats);
|
||||
this.validationLoss = Objects.requireNonNull(validationLoss);
|
||||
}
|
||||
|
||||
public ClassificationStats(StreamInput in) throws IOException {
|
||||
this.jobId = in.readString();
|
||||
this.timestamp = in.readInstant();
|
||||
this.iteration = in.readVInt();
|
||||
this.hyperparameters = new Hyperparameters(in);
|
||||
this.timingStats = new TimingStats(in);
|
||||
this.validationLoss = new ValidationLoss(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return TYPE_VALUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(jobId);
|
||||
out.writeInstant(timestamp);
|
||||
out.writeVInt(iteration);
|
||||
hyperparameters.writeTo(out);
|
||||
timingStats.writeTo(out);
|
||||
validationLoss.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
|
||||
builder.field(Fields.JOB_ID.getPreferredName(), jobId);
|
||||
}
|
||||
builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
builder.field(ITERATION.getPreferredName(), iteration);
|
||||
builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
|
||||
builder.field(TIMING_STATS.getPreferredName(), timingStats);
|
||||
builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassificationStats that = (ClassificationStats) o;
|
||||
return Objects.equals(jobId, that.jobId)
|
||||
&& Objects.equals(timestamp, that.timestamp)
|
||||
&& iteration == that.iteration
|
||||
&& Objects.equals(hyperparameters, that.hyperparameters)
|
||||
&& Objects.equals(timingStats, that.timingStats)
|
||||
&& Objects.equals(validationLoss, that.validationLoss);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(jobId, timestamp, iteration, hyperparameters, timingStats, validationLoss);
|
||||
}
|
||||
|
||||
public String documentId(String jobId) {
|
||||
return documentIdPrefix(jobId) + timestamp.toEpochMilli();
|
||||
}
|
||||
|
||||
public static String documentIdPrefix(String jobId) {
|
||||
return TYPE_VALUE + "_" + jobId + "_";
|
||||
}
|
||||
}
|
@ -0,0 +1,237 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class Hyperparameters implements ToXContentObject, Writeable {
|
||||
|
||||
public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
||||
public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
||||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
|
||||
public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
|
||||
"max_optimization_rounds_per_hyperparameter");
|
||||
public static final ParseField MAX_TREES = new ParseField("max_trees");
|
||||
public static final ParseField NUM_FOLDS = new ParseField("num_folds");
|
||||
public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
|
||||
public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
|
||||
= new ParseField("regularization_leaf_weight_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
|
||||
public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
|
||||
new ParseField("regularization_tree_size_penalty_multiplier");
|
||||
|
||||
public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<Hyperparameters, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<Hyperparameters, Void> parser = new ConstructingObjectParser<>("classification_hyperparameters",
|
||||
ignoreUnknownFields,
|
||||
a -> new Hyperparameters(
|
||||
(String) a[0],
|
||||
(double) a[1],
|
||||
(double) a[2],
|
||||
(double) a[3],
|
||||
(double) a[4],
|
||||
(int) a[5],
|
||||
(int) a[6],
|
||||
(int) a[7],
|
||||
(int) a[8],
|
||||
(int) a[9],
|
||||
(double) a[10],
|
||||
(double) a[11],
|
||||
(double) a[12],
|
||||
(double) a[13],
|
||||
(double) a[14]
|
||||
));
|
||||
|
||||
parser.declareString(constructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
|
||||
parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
|
||||
parser.declareDouble(constructorArg(), ETA);
|
||||
parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
|
||||
parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
|
||||
parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
|
||||
parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
||||
parser.declareInt(constructorArg(), MAX_TREES);
|
||||
parser.declareInt(constructorArg(), NUM_FOLDS);
|
||||
parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
|
||||
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final String classAssignmentObjective;
|
||||
private final double downsampleFactor;
|
||||
private final double eta;
|
||||
private final double etaGrowthRatePerTree;
|
||||
private final double featureBagFraction;
|
||||
private final int maxAttemptsToAddTree;
|
||||
private final int maxOptimizationRoundsPerHyperparameter;
|
||||
private final int maxTrees;
|
||||
private final int numFolds;
|
||||
private final int numSplitsPerFeature;
|
||||
private final double regularizationDepthPenaltyMultiplier;
|
||||
private final double regularizationLeafWeightPenaltyMultiplier;
|
||||
private final double regularizationSoftTreeDepthLimit;
|
||||
private final double regularizationSoftTreeDepthTolerance;
|
||||
private final double regularizationTreeSizePenaltyMultiplier;
|
||||
|
||||
public Hyperparameters(String classAssignmentObjective,
|
||||
double downsampleFactor,
|
||||
double eta,
|
||||
double etaGrowthRatePerTree,
|
||||
double featureBagFraction,
|
||||
int maxAttemptsToAddTree,
|
||||
int maxOptimizationRoundsPerHyperparameter,
|
||||
int maxTrees,
|
||||
int numFolds,
|
||||
int numSplitsPerFeature,
|
||||
double regularizationDepthPenaltyMultiplier,
|
||||
double regularizationLeafWeightPenaltyMultiplier,
|
||||
double regularizationSoftTreeDepthLimit,
|
||||
double regularizationSoftTreeDepthTolerance,
|
||||
double regularizationTreeSizePenaltyMultiplier) {
|
||||
this.classAssignmentObjective = Objects.requireNonNull(classAssignmentObjective);
|
||||
this.downsampleFactor = downsampleFactor;
|
||||
this.eta = eta;
|
||||
this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.maxAttemptsToAddTree = maxAttemptsToAddTree;
|
||||
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
||||
this.maxTrees = maxTrees;
|
||||
this.numFolds = numFolds;
|
||||
this.numSplitsPerFeature = numSplitsPerFeature;
|
||||
this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
|
||||
this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
|
||||
this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
|
||||
this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
|
||||
this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Hyperparameters(StreamInput in) throws IOException {
|
||||
this.classAssignmentObjective = in.readString();
|
||||
this.downsampleFactor = in.readDouble();
|
||||
this.eta = in.readDouble();
|
||||
this.etaGrowthRatePerTree = in.readDouble();
|
||||
this.featureBagFraction = in.readDouble();
|
||||
this.maxAttemptsToAddTree = in.readVInt();
|
||||
this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
|
||||
this.maxTrees = in.readVInt();
|
||||
this.numFolds = in.readVInt();
|
||||
this.numSplitsPerFeature = in.readVInt();
|
||||
this.regularizationDepthPenaltyMultiplier = in.readDouble();
|
||||
this.regularizationLeafWeightPenaltyMultiplier = in.readDouble();
|
||||
this.regularizationSoftTreeDepthLimit = in.readDouble();
|
||||
this.regularizationSoftTreeDepthTolerance = in.readDouble();
|
||||
this.regularizationTreeSizePenaltyMultiplier = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(classAssignmentObjective);
|
||||
out.writeDouble(downsampleFactor);
|
||||
out.writeDouble(eta);
|
||||
out.writeDouble(etaGrowthRatePerTree);
|
||||
out.writeDouble(featureBagFraction);
|
||||
out.writeVInt(maxAttemptsToAddTree);
|
||||
out.writeVInt(maxOptimizationRoundsPerHyperparameter);
|
||||
out.writeVInt(maxTrees);
|
||||
out.writeVInt(numFolds);
|
||||
out.writeVInt(numSplitsPerFeature);
|
||||
out.writeDouble(regularizationDepthPenaltyMultiplier);
|
||||
out.writeDouble(regularizationLeafWeightPenaltyMultiplier);
|
||||
out.writeDouble(regularizationSoftTreeDepthLimit);
|
||||
out.writeDouble(regularizationSoftTreeDepthTolerance);
|
||||
out.writeDouble(regularizationTreeSizePenaltyMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
|
||||
builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
|
||||
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
||||
builder.field(MAX_TREES.getPreferredName(), maxTrees);
|
||||
builder.field(NUM_FOLDS.getPreferredName(), numFolds);
|
||||
builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
|
||||
builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
|
||||
builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
|
||||
builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Hyperparameters that = (Hyperparameters) o;
|
||||
return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
||||
&& downsampleFactor == that.downsampleFactor
|
||||
&& eta == that.eta
|
||||
&& etaGrowthRatePerTree == that.etaGrowthRatePerTree
|
||||
&& featureBagFraction == that.featureBagFraction
|
||||
&& maxAttemptsToAddTree == that.maxAttemptsToAddTree
|
||||
&& maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
|
||||
&& maxTrees == that.maxTrees
|
||||
&& numFolds == that.numFolds
|
||||
&& numSplitsPerFeature == that.numSplitsPerFeature
|
||||
&& regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier
|
||||
&& regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier
|
||||
&& regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit
|
||||
&& regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance
|
||||
&& regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(
|
||||
classAssignmentObjective,
|
||||
downsampleFactor,
|
||||
eta,
|
||||
etaGrowthRatePerTree,
|
||||
featureBagFraction,
|
||||
maxAttemptsToAddTree,
|
||||
maxOptimizationRoundsPerHyperparameter,
|
||||
maxTrees,
|
||||
numFolds,
|
||||
numSplitsPerFeature,
|
||||
regularizationDepthPenaltyMultiplier,
|
||||
regularizationLeafWeightPenaltyMultiplier,
|
||||
regularizationSoftTreeDepthLimit,
|
||||
regularizationSoftTreeDepthTolerance,
|
||||
regularizationTreeSizePenaltyMultiplier
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TimingStats implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
|
||||
public static final ParseField ITERATION_TIME = new ParseField("iteration_time");
|
||||
|
||||
public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<TimingStats, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<TimingStats, Void> parser = new ConstructingObjectParser<>("classification_timing_stats",
|
||||
ignoreUnknownFields,
|
||||
a -> new TimingStats(TimeValue.timeValueMillis((long) a[0]), TimeValue.timeValueMillis((long) a[1])));
|
||||
|
||||
parser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME);
|
||||
parser.declareLong(ConstructingObjectParser.constructorArg(), ITERATION_TIME);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final TimeValue elapsedTime;
|
||||
private final TimeValue iterationTime;
|
||||
|
||||
public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
|
||||
this.elapsedTime = Objects.requireNonNull(elapsedTime);
|
||||
this.iterationTime = Objects.requireNonNull(iterationTime);
|
||||
}
|
||||
|
||||
public TimingStats(StreamInput in) throws IOException {
|
||||
this.elapsedTime = in.readTimeValue();
|
||||
this.iterationTime = in.readTimeValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeTimeValue(elapsedTime);
|
||||
out.writeTimeValue(iterationTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
|
||||
builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TimingStats that = (TimingStats) o;
|
||||
return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(elapsedTime, iterationTime);
|
||||
}
|
||||
}
|
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValues;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ValidationLoss implements ToXContentObject, Writeable {
|
||||
|
||||
public static final ParseField LOSS_TYPE = new ParseField("loss_type");
|
||||
public static final ParseField FOLD_VALUES = new ParseField("fold_values");
|
||||
|
||||
public static ValidationLoss fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<ValidationLoss, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<ValidationLoss, Void> parser = new ConstructingObjectParser<>("classification_validation_loss",
|
||||
ignoreUnknownFields,
|
||||
a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));
|
||||
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), LOSS_TYPE);
|
||||
parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> FoldValues.fromXContent(p, ignoreUnknownFields), FOLD_VALUES);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final String lossType;
|
||||
private final List<FoldValues> foldValues;
|
||||
|
||||
public ValidationLoss(String lossType, List<FoldValues> values) {
|
||||
this.lossType = Objects.requireNonNull(lossType);
|
||||
this.foldValues = Objects.requireNonNull(values);
|
||||
}
|
||||
|
||||
public ValidationLoss(StreamInput in) throws IOException {
|
||||
lossType = in.readString();
|
||||
foldValues = in.readList(FoldValues::new);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(lossType);
|
||||
out.writeList(foldValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(LOSS_TYPE.getPreferredName(), lossType);
|
||||
builder.field(FOLD_VALUES.getPreferredName(), foldValues);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ValidationLoss that = (ValidationLoss) o;
|
||||
return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(lossType, foldValues);
|
||||
}
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class FoldValues implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField FOLD = new ParseField("fold");
|
||||
public static final ParseField VALUES = new ParseField("values");
|
||||
|
||||
public static FoldValues fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<FoldValues, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<FoldValues, Void> parser = new ConstructingObjectParser<>("fold_values", ignoreUnknownFields,
|
||||
a -> new FoldValues((int) a[0], (List<Double>) a[1]));
|
||||
parser.declareInt(ConstructingObjectParser.constructorArg(), FOLD);
|
||||
parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), VALUES);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final int fold;
|
||||
private final double[] values;
|
||||
|
||||
private FoldValues(int fold, List<Double> values) {
|
||||
this(fold, values.stream().mapToDouble(Double::doubleValue).toArray());
|
||||
}
|
||||
|
||||
public FoldValues(int fold, double[] values) {
|
||||
this.fold = fold;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
public FoldValues(StreamInput in) throws IOException {
|
||||
fold = in.readVInt();
|
||||
values = in.readDoubleArray();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(fold);
|
||||
out.writeDoubleArray(values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FOLD.getPreferredName(), fold);
|
||||
builder.array(VALUES.getPreferredName(), values);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (o == this) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
FoldValues other = (FoldValues) o;
|
||||
return fold == other.fold && Arrays.equals(values, other.values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(fold, Arrays.hashCode(values));
|
||||
}
|
||||
}
|
@ -0,0 +1,122 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.common.time.TimeUtils;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public class OutlierDetectionStats implements AnalysisStats {
|
||||
|
||||
public static final String TYPE_VALUE = "outlier_detection_stats";
|
||||
|
||||
public static final ParseField PARAMETERS = new ParseField("parameters");
|
||||
public static final ParseField TIMING_STATS = new ParseField("timings_stats");
|
||||
|
||||
public static final ConstructingObjectParser<OutlierDetectionStats, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<OutlierDetectionStats, Void> LENIENT_PARSER = createParser(true);
|
||||
|
||||
private static ConstructingObjectParser<OutlierDetectionStats, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<OutlierDetectionStats, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
|
||||
a -> new OutlierDetectionStats((String) a[0], (Instant) a[1], (Parameters) a[2], (TimingStats) a[3]));
|
||||
|
||||
parser.declareString((bucket, s) -> {}, Fields.TYPE);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
|
||||
parser.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
|
||||
Fields.TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> Parameters.fromXContent(p, ignoreUnknownFields), PARAMETERS);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final String jobId;
|
||||
private final Instant timestamp;
|
||||
private final Parameters parameters;
|
||||
private final TimingStats timingStats;
|
||||
|
||||
public OutlierDetectionStats(String jobId, Instant timestamp, Parameters parameters, TimingStats timingStats) {
|
||||
this.jobId = Objects.requireNonNull(jobId);
|
||||
// We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
|
||||
// internal representation matches toXContent
|
||||
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
|
||||
this.parameters = Objects.requireNonNull(parameters);
|
||||
this.timingStats = Objects.requireNonNull(timingStats);
|
||||
}
|
||||
|
||||
public OutlierDetectionStats(StreamInput in) throws IOException {
|
||||
this.jobId = in.readString();
|
||||
this.timestamp = in.readInstant();
|
||||
this.parameters = new Parameters(in);
|
||||
this.timingStats = new TimingStats(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return TYPE_VALUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(jobId);
|
||||
out.writeInstant(timestamp);
|
||||
parameters.writeTo(out);
|
||||
timingStats.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
|
||||
builder.field(Fields.JOB_ID.getPreferredName(), jobId);
|
||||
}
|
||||
builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
builder.field(PARAMETERS.getPreferredName(), parameters);
|
||||
builder.field(TIMING_STATS.getPreferredName(), timingStats);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
OutlierDetectionStats that = (OutlierDetectionStats) o;
|
||||
return Objects.equals(jobId, that.jobId)
|
||||
&& Objects.equals(timestamp, that.timestamp)
|
||||
&& Objects.equals(parameters, that.parameters)
|
||||
&& Objects.equals(timingStats, that.timingStats);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(jobId, timestamp, parameters, timingStats);
|
||||
}
|
||||
|
||||
public String documentId(String jobId) {
|
||||
return documentIdPrefix(jobId) + timestamp.toEpochMilli();
|
||||
}
|
||||
|
||||
public static String documentIdPrefix(String jobId) {
|
||||
return TYPE_VALUE + "_" + jobId + "_";
|
||||
}
|
||||
}
|
@ -0,0 +1,125 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class Parameters implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors");
|
||||
public static final ParseField METHOD = new ParseField("method");
|
||||
public static final ParseField FEATURE_INFLUENCE_THRESHOLD = new ParseField("feature_influence_threshold");
|
||||
public static final ParseField COMPUTE_FEATURE_INFLUENCE = new ParseField("compute_feature_influence");
|
||||
public static final ParseField OUTLIER_FRACTION = new ParseField("outlier_fraction");
|
||||
public static final ParseField STANDARDIZATION_ENABLED = new ParseField("standardization_enabled");
|
||||
|
||||
public static Parameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<Parameters, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<Parameters, Void> parser = new ConstructingObjectParser<>("outlier_detection_parameters",
|
||||
ignoreUnknownFields,
|
||||
a -> new Parameters(
|
||||
(int) a[0],
|
||||
(String) a[1],
|
||||
(boolean) a[2],
|
||||
(double) a[3],
|
||||
(double) a[4],
|
||||
(boolean) a[5]
|
||||
));
|
||||
|
||||
parser.declareInt(constructorArg(), N_NEIGHBORS);
|
||||
parser.declareString(constructorArg(), METHOD);
|
||||
parser.declareBoolean(constructorArg(), COMPUTE_FEATURE_INFLUENCE);
|
||||
parser.declareDouble(constructorArg(), FEATURE_INFLUENCE_THRESHOLD);
|
||||
parser.declareDouble(constructorArg(), OUTLIER_FRACTION);
|
||||
parser.declareBoolean(constructorArg(), STANDARDIZATION_ENABLED);
|
||||
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final int nNeighbors;
|
||||
private final String method;
|
||||
private final boolean computeFeatureInfluence;
|
||||
private final double featureInfluenceThreshold;
|
||||
private final double outlierFraction;
|
||||
private final boolean standardizationEnabled;
|
||||
|
||||
public Parameters(int nNeighbors, String method, boolean computeFeatureInfluence, double featureInfluenceThreshold,
|
||||
double outlierFraction, boolean standardizationEnabled) {
|
||||
this.nNeighbors = nNeighbors;
|
||||
this.method = method;
|
||||
this.computeFeatureInfluence = computeFeatureInfluence;
|
||||
this.featureInfluenceThreshold = featureInfluenceThreshold;
|
||||
this.outlierFraction = outlierFraction;
|
||||
this.standardizationEnabled = standardizationEnabled;
|
||||
}
|
||||
|
||||
public Parameters(StreamInput in) throws IOException {
|
||||
this.nNeighbors = in.readVInt();
|
||||
this.method = in.readString();
|
||||
this.computeFeatureInfluence = in.readBoolean();
|
||||
this.featureInfluenceThreshold = in.readDouble();
|
||||
this.outlierFraction = in.readDouble();
|
||||
this.standardizationEnabled = in.readBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(nNeighbors);
|
||||
out.writeString(method);
|
||||
out.writeBoolean(computeFeatureInfluence);
|
||||
out.writeDouble(featureInfluenceThreshold);
|
||||
out.writeDouble(outlierFraction);
|
||||
out.writeBoolean(standardizationEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors);
|
||||
builder.field(METHOD.getPreferredName(), method);
|
||||
builder.field(COMPUTE_FEATURE_INFLUENCE.getPreferredName(), computeFeatureInfluence);
|
||||
builder.field(FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), featureInfluenceThreshold);
|
||||
builder.field(OUTLIER_FRACTION.getPreferredName(), outlierFraction);
|
||||
builder.field(STANDARDIZATION_ENABLED.getPreferredName(), standardizationEnabled);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Parameters that = (Parameters) o;
|
||||
return nNeighbors == that.nNeighbors
|
||||
&& Objects.equals(method, that.method)
|
||||
&& computeFeatureInfluence == that.computeFeatureInfluence
|
||||
&& featureInfluenceThreshold == that.featureInfluenceThreshold
|
||||
&& outlierFraction == that.outlierFraction
|
||||
&& standardizationEnabled == that.standardizationEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(nNeighbors, method, computeFeatureInfluence, featureInfluenceThreshold, outlierFraction,
|
||||
standardizationEnabled);
|
||||
}
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TimingStats implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
|
||||
|
||||
public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<TimingStats, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<TimingStats, Void> parser = new ConstructingObjectParser<>("outlier_detection_timing_stats",
|
||||
ignoreUnknownFields,
|
||||
a -> new TimingStats(TimeValue.timeValueMillis((long) a[0])));
|
||||
|
||||
parser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final TimeValue elapsedTime;
|
||||
|
||||
public TimingStats(TimeValue elapsedTime) {
|
||||
this.elapsedTime = Objects.requireNonNull(elapsedTime);
|
||||
}
|
||||
|
||||
public TimingStats(StreamInput in) throws IOException {
|
||||
this.elapsedTime = in.readTimeValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeTimeValue(elapsedTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TimingStats that = (TimingStats) o;
|
||||
return Objects.equals(elapsedTime, that.elapsedTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(elapsedTime);
|
||||
}
|
||||
}
|
@ -0,0 +1,226 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class Hyperparameters implements ToXContentObject, Writeable {
|
||||
|
||||
public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
||||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
|
||||
public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
|
||||
"max_optimization_rounds_per_hyperparameter");
|
||||
public static final ParseField MAX_TREES = new ParseField("max_trees");
|
||||
public static final ParseField NUM_FOLDS = new ParseField("num_folds");
|
||||
public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
|
||||
public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
|
||||
= new ParseField("regularization_leaf_weight_penalty_multiplier");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
|
||||
public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
|
||||
public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
|
||||
new ParseField("regularization_tree_size_penalty_multiplier");
|
||||
|
||||
public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<Hyperparameters, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<Hyperparameters, Void> parser = new ConstructingObjectParser<>("regression_hyperparameters",
|
||||
ignoreUnknownFields,
|
||||
a -> new Hyperparameters(
|
||||
(double) a[0],
|
||||
(double) a[1],
|
||||
(double) a[2],
|
||||
(double) a[3],
|
||||
(int) a[4],
|
||||
(int) a[5],
|
||||
(int) a[6],
|
||||
(int) a[7],
|
||||
(int) a[8],
|
||||
(double) a[9],
|
||||
(double) a[10],
|
||||
(double) a[11],
|
||||
(double) a[12],
|
||||
(double) a[13]
|
||||
));
|
||||
|
||||
parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
|
||||
parser.declareDouble(constructorArg(), ETA);
|
||||
parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
|
||||
parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
|
||||
parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
|
||||
parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
||||
parser.declareInt(constructorArg(), MAX_TREES);
|
||||
parser.declareInt(constructorArg(), NUM_FOLDS);
|
||||
parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
|
||||
parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
|
||||
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final double downsampleFactor;
|
||||
private final double eta;
|
||||
private final double etaGrowthRatePerTree;
|
||||
private final double featureBagFraction;
|
||||
private final int maxAttemptsToAddTree;
|
||||
private final int maxOptimizationRoundsPerHyperparameter;
|
||||
private final int maxTrees;
|
||||
private final int numFolds;
|
||||
private final int numSplitsPerFeature;
|
||||
private final double regularizationDepthPenaltyMultiplier;
|
||||
private final double regularizationLeafWeightPenaltyMultiplier;
|
||||
private final double regularizationSoftTreeDepthLimit;
|
||||
private final double regularizationSoftTreeDepthTolerance;
|
||||
private final double regularizationTreeSizePenaltyMultiplier;
|
||||
|
||||
public Hyperparameters(double downsampleFactor,
|
||||
double eta,
|
||||
double etaGrowthRatePerTree,
|
||||
double featureBagFraction,
|
||||
int maxAttemptsToAddTree,
|
||||
int maxOptimizationRoundsPerHyperparameter,
|
||||
int maxTrees,
|
||||
int numFolds,
|
||||
int numSplitsPerFeature,
|
||||
double regularizationDepthPenaltyMultiplier,
|
||||
double regularizationLeafWeightPenaltyMultiplier,
|
||||
double regularizationSoftTreeDepthLimit,
|
||||
double regularizationSoftTreeDepthTolerance,
|
||||
double regularizationTreeSizePenaltyMultiplier) {
|
||||
this.downsampleFactor = downsampleFactor;
|
||||
this.eta = eta;
|
||||
this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.maxAttemptsToAddTree = maxAttemptsToAddTree;
|
||||
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
||||
this.maxTrees = maxTrees;
|
||||
this.numFolds = numFolds;
|
||||
this.numSplitsPerFeature = numSplitsPerFeature;
|
||||
this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
|
||||
this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
|
||||
this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
|
||||
this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
|
||||
this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
public Hyperparameters(StreamInput in) throws IOException {
|
||||
this.downsampleFactor = in.readDouble();
|
||||
this.eta = in.readDouble();
|
||||
this.etaGrowthRatePerTree = in.readDouble();
|
||||
this.featureBagFraction = in.readDouble();
|
||||
this.maxAttemptsToAddTree = in.readVInt();
|
||||
this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
|
||||
this.maxTrees = in.readVInt();
|
||||
this.numFolds = in.readVInt();
|
||||
this.numSplitsPerFeature = in.readVInt();
|
||||
this.regularizationDepthPenaltyMultiplier = in.readDouble();
|
||||
this.regularizationLeafWeightPenaltyMultiplier = in.readDouble();
|
||||
this.regularizationSoftTreeDepthLimit = in.readDouble();
|
||||
this.regularizationSoftTreeDepthTolerance = in.readDouble();
|
||||
this.regularizationTreeSizePenaltyMultiplier = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(downsampleFactor);
|
||||
out.writeDouble(eta);
|
||||
out.writeDouble(etaGrowthRatePerTree);
|
||||
out.writeDouble(featureBagFraction);
|
||||
out.writeVInt(maxAttemptsToAddTree);
|
||||
out.writeVInt(maxOptimizationRoundsPerHyperparameter);
|
||||
out.writeVInt(maxTrees);
|
||||
out.writeVInt(numFolds);
|
||||
out.writeVInt(numSplitsPerFeature);
|
||||
out.writeDouble(regularizationDepthPenaltyMultiplier);
|
||||
out.writeDouble(regularizationLeafWeightPenaltyMultiplier);
|
||||
out.writeDouble(regularizationSoftTreeDepthLimit);
|
||||
out.writeDouble(regularizationSoftTreeDepthTolerance);
|
||||
out.writeDouble(regularizationTreeSizePenaltyMultiplier);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
||||
builder.field(ETA.getPreferredName(), eta);
|
||||
builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
|
||||
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
||||
builder.field(MAX_TREES.getPreferredName(), maxTrees);
|
||||
builder.field(NUM_FOLDS.getPreferredName(), numFolds);
|
||||
builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
|
||||
builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
|
||||
builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
|
||||
builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
|
||||
builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Hyperparameters that = (Hyperparameters) o;
|
||||
return downsampleFactor == that.downsampleFactor
|
||||
&& eta == that.eta
|
||||
&& etaGrowthRatePerTree == that.etaGrowthRatePerTree
|
||||
&& featureBagFraction == that.featureBagFraction
|
||||
&& maxAttemptsToAddTree == that.maxAttemptsToAddTree
|
||||
&& maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
|
||||
&& maxTrees == that.maxTrees
|
||||
&& numFolds == that.numFolds
|
||||
&& numSplitsPerFeature == that.numSplitsPerFeature
|
||||
&& regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier
|
||||
&& regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier
|
||||
&& regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit
|
||||
&& regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance
|
||||
&& regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(
|
||||
downsampleFactor,
|
||||
eta,
|
||||
etaGrowthRatePerTree,
|
||||
featureBagFraction,
|
||||
maxAttemptsToAddTree,
|
||||
maxOptimizationRoundsPerHyperparameter,
|
||||
maxTrees,
|
||||
numFolds,
|
||||
numSplitsPerFeature,
|
||||
regularizationDepthPenaltyMultiplier,
|
||||
regularizationLeafWeightPenaltyMultiplier,
|
||||
regularizationSoftTreeDepthLimit,
|
||||
regularizationSoftTreeDepthTolerance,
|
||||
regularizationTreeSizePenaltyMultiplier
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,148 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.common.time.TimeUtils;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public class RegressionStats implements AnalysisStats {
|
||||
|
||||
public static final String TYPE_VALUE = "regression_stats";
|
||||
|
||||
public static final ParseField ITERATION = new ParseField("iteration");
|
||||
public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
|
||||
public static final ParseField TIMING_STATS = new ParseField("timing_stats");
|
||||
public static final ParseField VALIDATION_LOSS = new ParseField("validation_loss");
|
||||
|
||||
public static final ConstructingObjectParser<RegressionStats, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<RegressionStats, Void> LENIENT_PARSER = createParser(true);
|
||||
|
||||
private static ConstructingObjectParser<RegressionStats, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<RegressionStats, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
|
||||
a -> new RegressionStats(
|
||||
(String) a[0],
|
||||
(Instant) a[1],
|
||||
(int) a[2],
|
||||
(Hyperparameters) a[3],
|
||||
(TimingStats) a[4],
|
||||
(ValidationLoss) a[5]
|
||||
)
|
||||
);
|
||||
|
||||
parser.declareString((bucket, s) -> {}, Fields.TYPE);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
|
||||
parser.declareField(ConstructingObjectParser.constructorArg(),
|
||||
p -> TimeUtils.parseTimeFieldToInstant(p, Fields.TIMESTAMP.getPreferredName()),
|
||||
Fields.TIMESTAMP,
|
||||
ObjectParser.ValueType.VALUE);
|
||||
parser.declareInt(ConstructingObjectParser.constructorArg(), ITERATION);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> Hyperparameters.fromXContent(p, ignoreUnknownFields), HYPERPARAMETERS);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> TimingStats.fromXContent(p, ignoreUnknownFields), TIMING_STATS);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> ValidationLoss.fromXContent(p, ignoreUnknownFields), VALIDATION_LOSS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final String jobId;
|
||||
private final Instant timestamp;
|
||||
private final int iteration;
|
||||
private final Hyperparameters hyperparameters;
|
||||
private final TimingStats timingStats;
|
||||
private final ValidationLoss validationLoss;
|
||||
|
||||
public RegressionStats(String jobId, Instant timestamp, int iteration, Hyperparameters hyperparameters, TimingStats timingStats,
|
||||
ValidationLoss validationLoss) {
|
||||
this.jobId = Objects.requireNonNull(jobId);
|
||||
// We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
|
||||
// internal representation matches toXContent
|
||||
this.timestamp = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
|
||||
this.iteration = iteration;
|
||||
this.hyperparameters = Objects.requireNonNull(hyperparameters);
|
||||
this.timingStats = Objects.requireNonNull(timingStats);
|
||||
this.validationLoss = Objects.requireNonNull(validationLoss);
|
||||
}
|
||||
|
||||
public RegressionStats(StreamInput in) throws IOException {
|
||||
this.jobId = in.readString();
|
||||
this.timestamp = in.readInstant();
|
||||
this.iteration = in.readVInt();
|
||||
this.hyperparameters = new Hyperparameters(in);
|
||||
this.timingStats = new TimingStats(in);
|
||||
this.validationLoss = new ValidationLoss(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return TYPE_VALUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(jobId);
|
||||
out.writeInstant(timestamp);
|
||||
out.writeVInt(iteration);
|
||||
hyperparameters.writeTo(out);
|
||||
timingStats.writeTo(out);
|
||||
validationLoss.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
|
||||
builder.field(Fields.JOB_ID.getPreferredName(), jobId);
|
||||
}
|
||||
builder.timeField(Fields.TIMESTAMP.getPreferredName(), Fields.TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
|
||||
builder.field(ITERATION.getPreferredName(), iteration);
|
||||
builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
|
||||
builder.field(TIMING_STATS.getPreferredName(), timingStats);
|
||||
builder.field(VALIDATION_LOSS.getPreferredName(), validationLoss);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
RegressionStats that = (RegressionStats) o;
|
||||
return Objects.equals(jobId, that.jobId)
|
||||
&& Objects.equals(timestamp, that.timestamp)
|
||||
&& iteration == that.iteration
|
||||
&& Objects.equals(hyperparameters, that.hyperparameters)
|
||||
&& Objects.equals(timingStats, that.timingStats)
|
||||
&& Objects.equals(validationLoss, that.validationLoss);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(jobId, timestamp, iteration, hyperparameters, timingStats, validationLoss);
|
||||
}
|
||||
|
||||
public String documentId(String jobId) {
|
||||
return documentIdPrefix(jobId) + timestamp.toEpochMilli();
|
||||
}
|
||||
|
||||
public static String documentIdPrefix(String jobId) {
|
||||
return TYPE_VALUE + "_" + jobId + "_";
|
||||
}
|
||||
}
|
@ -0,0 +1,79 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TimingStats implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField ELAPSED_TIME = new ParseField("elapsed_time");
|
||||
public static final ParseField ITERATION_TIME = new ParseField("iteration_time");
|
||||
|
||||
public static TimingStats fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<TimingStats, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<TimingStats, Void> parser = new ConstructingObjectParser<>("regression_timing_stats", ignoreUnknownFields,
|
||||
a -> new TimingStats(TimeValue.timeValueMillis((long) a[0]), TimeValue.timeValueMillis((long) a[1])));
|
||||
|
||||
parser.declareLong(ConstructingObjectParser.constructorArg(), ELAPSED_TIME);
|
||||
parser.declareLong(ConstructingObjectParser.constructorArg(), ITERATION_TIME);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final TimeValue elapsedTime;
|
||||
private final TimeValue iterationTime;
|
||||
|
||||
public TimingStats(TimeValue elapsedTime, TimeValue iterationTime) {
|
||||
this.elapsedTime = Objects.requireNonNull(elapsedTime);
|
||||
this.iterationTime = Objects.requireNonNull(iterationTime);
|
||||
}
|
||||
|
||||
public TimingStats(StreamInput in) throws IOException {
|
||||
this.elapsedTime = in.readTimeValue();
|
||||
this.iterationTime = in.readTimeValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeTimeValue(elapsedTime);
|
||||
out.writeTimeValue(iterationTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.humanReadableField(ELAPSED_TIME.getPreferredName(), ELAPSED_TIME.getPreferredName() + "_string", elapsedTime);
|
||||
builder.humanReadableField(ITERATION_TIME.getPreferredName(), ITERATION_TIME.getPreferredName() + "_string", iterationTime);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TimingStats that = (TimingStats) o;
|
||||
return Objects.equals(elapsedTime, that.elapsedTime) && Objects.equals(iterationTime, that.iterationTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(elapsedTime, iterationTime);
|
||||
}
|
||||
}
|
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValues;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ValidationLoss implements ToXContentObject, Writeable {
|
||||
|
||||
public static final ParseField LOSS_TYPE = new ParseField("loss_type");
|
||||
public static final ParseField FOLD_VALUES = new ParseField("fold_values");
|
||||
|
||||
public static ValidationLoss fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
|
||||
return createParser(ignoreUnknownFields).apply(parser, null);
|
||||
}
|
||||
|
||||
private static ConstructingObjectParser<ValidationLoss, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<ValidationLoss, Void> parser = new ConstructingObjectParser<>("regression_validation_loss",
|
||||
ignoreUnknownFields,
|
||||
a -> new ValidationLoss((String) a[0], (List<FoldValues>) a[1]));
|
||||
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), LOSS_TYPE);
|
||||
parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> FoldValues.fromXContent(p, ignoreUnknownFields), FOLD_VALUES);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final String lossType;
|
||||
private final List<FoldValues> foldValues;
|
||||
|
||||
public ValidationLoss(String lossType, List<FoldValues> values) {
|
||||
this.lossType = Objects.requireNonNull(lossType);
|
||||
this.foldValues = Objects.requireNonNull(values);
|
||||
}
|
||||
|
||||
public ValidationLoss(StreamInput in) throws IOException {
|
||||
lossType = in.readString();
|
||||
foldValues = in.readList(FoldValues::new);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(lossType);
|
||||
out.writeList(foldValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(LOSS_TYPE.getPreferredName(), lossType);
|
||||
builder.field(FOLD_VALUES.getPreferredName(), foldValues);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ValidationLoss that = (ValidationLoss) o;
|
||||
return Objects.equals(lossType, that.lossType) && Objects.equals(foldValues, that.foldValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(lossType, foldValues);
|
||||
}
|
||||
}
|
@ -3,18 +3,120 @@
|
||||
"_meta": {
|
||||
"version" : "${xpack.ml.version}"
|
||||
},
|
||||
"dynamic": false,
|
||||
"properties" : {
|
||||
"type" : {
|
||||
"type" : "keyword"
|
||||
"iteration": {
|
||||
"type": "integer"
|
||||
},
|
||||
"hyperparameters": {
|
||||
"properties": {
|
||||
"class_assignment_objective": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"downsample_factor": {
|
||||
"type": "double"
|
||||
},
|
||||
"eta": {
|
||||
"type": "double"
|
||||
},
|
||||
"eta_growth_rate_per_tree": {
|
||||
"type": "double"
|
||||
},
|
||||
"feature_bag_fraction": {
|
||||
"type": "double"
|
||||
},
|
||||
"max_attempts_to_add_tree": {
|
||||
"type": "integer"
|
||||
},
|
||||
"max_optimization_rounds_per_hyperparameter": {
|
||||
"type": "integer"
|
||||
},
|
||||
"max_trees": {
|
||||
"type": "integer"
|
||||
},
|
||||
"num_folds": {
|
||||
"type": "integer"
|
||||
},
|
||||
"num_splits_per_feature": {
|
||||
"type": "integer"
|
||||
},
|
||||
"regularization_depth_penalty_multiplier": {
|
||||
"type": "double"
|
||||
},
|
||||
"regularization_leaf_weight_penalty_multiplier": {
|
||||
"type": "double"
|
||||
},
|
||||
"regularization_soft_tree_depth_limit": {
|
||||
"type": "double"
|
||||
},
|
||||
"regularization_soft_tree_depth_tolerance": {
|
||||
"type": "double"
|
||||
},
|
||||
"regularization_tree_size_penalty_multiplier": {
|
||||
"type": "double"
|
||||
}
|
||||
}
|
||||
},
|
||||
"job_id" : {
|
||||
"type" : "keyword"
|
||||
},
|
||||
"timestamp" : {
|
||||
"type" : "date"
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"compute_feature_influence": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"feature_influence_threshold": {
|
||||
"type": "double"
|
||||
},
|
||||
"method": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"n_neighbors": {
|
||||
"type": "integer"
|
||||
},
|
||||
"outlier_fraction": {
|
||||
"type": "double"
|
||||
},
|
||||
"standardization_enabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"peak_usage_bytes" : {
|
||||
"type" : "long"
|
||||
},
|
||||
"timestamp" : {
|
||||
"type" : "date"
|
||||
},
|
||||
"timing_stats": {
|
||||
"properties": {
|
||||
"elapsed_time": {
|
||||
"type": "long"
|
||||
},
|
||||
"iteration_time": {
|
||||
"type": "long"
|
||||
}
|
||||
}
|
||||
},
|
||||
"type" : {
|
||||
"type" : "keyword"
|
||||
},
|
||||
"validation_loss": {
|
||||
"properties": {
|
||||
"fold_values": {
|
||||
"properties": {
|
||||
"fold": {
|
||||
"type": "integer"
|
||||
},
|
||||
"values": {
|
||||
"type": "double"
|
||||
}
|
||||
}
|
||||
},
|
||||
"loss_type": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,14 +5,20 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -21,6 +27,13 @@ import java.util.stream.IntStream;
|
||||
|
||||
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
|
||||
public static Response randomResponse(int listSize) {
|
||||
List<Response.Stats> analytics = new ArrayList<>(listSize);
|
||||
for (int j = 0; j < listSize; j++) {
|
||||
@ -30,8 +43,15 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
|
||||
IntStream.of(progressSize).forEach(progressIndex -> progress.add(
|
||||
new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))));
|
||||
MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom();
|
||||
AnalysisStats analysisStats = randomBoolean() ? null :
|
||||
randomFrom(
|
||||
ClassificationStatsTests.createRandom(),
|
||||
OutlierDetectionStatsTests.createRandom(),
|
||||
RegressionStatsTests.createRandom()
|
||||
);
|
||||
Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(),
|
||||
randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, null, randomAlphaOfLength(20));
|
||||
randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null,
|
||||
randomAlphaOfLength(20));
|
||||
analytics.add(stats);
|
||||
}
|
||||
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
|
||||
|
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
|
||||
public class ClassificationStatsTests extends AbstractBWCSerializationTestCase<ClassificationStats> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationStats mutateInstanceForVersion(ClassificationStats instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? ClassificationStats.LENIENT_PARSER.apply(parser, null) : ClassificationStats.STRICT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ToXContent.Params getToXContentParams() {
|
||||
return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ClassificationStats> instanceReader() {
|
||||
return ClassificationStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static ClassificationStats createRandom() {
|
||||
return new ClassificationStats(
|
||||
randomAlphaOfLength(10),
|
||||
Instant.now(),
|
||||
randomIntBetween(1, Integer.MAX_VALUE),
|
||||
HyperparametersTests.createRandom(),
|
||||
TimingStatsTests.createRandom(),
|
||||
ValidationLossTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyperparameters> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
|
||||
return Hyperparameters.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Hyperparameters> instanceReader() {
|
||||
return Hyperparameters::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Hyperparameters createRandom() {
|
||||
return new Hyperparameters(
|
||||
randomAlphaOfLength(10),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TimingStatsTests extends AbstractBWCSerializationTestCase<TimingStats> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return TimingStats.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<TimingStats> instanceReader() {
|
||||
return TimingStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TimingStats createRandom() {
|
||||
return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()), TimeValue.timeValueMillis(randomNonNegativeLong()));
|
||||
}
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValuesTests;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ValidationLossTests extends AbstractBWCSerializationTestCase<ValidationLoss> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
|
||||
return ValidationLoss.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ValidationLoss> instanceReader() {
|
||||
return ValidationLoss::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static ValidationLoss createRandom() {
|
||||
return new ValidationLoss(
|
||||
randomAlphaOfLength(10),
|
||||
randomList(5, () -> FoldValuesTests.createRandom())
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss mutateInstanceForVersion(ValidationLoss instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class FoldValuesTests extends AbstractBWCSerializationTestCase<FoldValues> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FoldValues doParseInstance(XContentParser parser) throws IOException {
|
||||
return FoldValues.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<FoldValues> instanceReader() {
|
||||
return FoldValues::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FoldValues createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static FoldValues createRandom() {
|
||||
int valuesSize = randomIntBetween(0, 10);
|
||||
double[] values = new double[valuesSize];
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
values[i] = randomDouble();
|
||||
}
|
||||
return new FoldValues(randomIntBetween(0, Integer.MAX_VALUE), values);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FoldValues mutateInstanceForVersion(FoldValues instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
|
||||
public class OutlierDetectionStatsTests extends AbstractBWCSerializationTestCase<OutlierDetectionStats> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OutlierDetectionStats mutateInstanceForVersion(OutlierDetectionStats instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OutlierDetectionStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? OutlierDetectionStats.LENIENT_PARSER.apply(parser, null)
|
||||
: OutlierDetectionStats.STRICT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ToXContent.Params getToXContentParams() {
|
||||
return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<OutlierDetectionStats> instanceReader() {
|
||||
return OutlierDetectionStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OutlierDetectionStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static OutlierDetectionStats createRandom() {
|
||||
return new OutlierDetectionStats(
|
||||
randomAlphaOfLength(10),
|
||||
Instant.now(),
|
||||
ParametersTests.createRandom(),
|
||||
TimingStatsTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ParametersTests extends AbstractBWCSerializationTestCase<Parameters> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Parameters mutateInstanceForVersion(Parameters instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Parameters doParseInstance(XContentParser parser) throws IOException {
|
||||
return Parameters.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Parameters> instanceReader() {
|
||||
return Parameters::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Parameters createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Parameters createRandom() {
|
||||
|
||||
return new Parameters(
|
||||
randomIntBetween(1, Integer.MAX_VALUE),
|
||||
randomAlphaOfLength(5),
|
||||
randomBoolean(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomBoolean()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TimingStatsTests extends AbstractBWCSerializationTestCase<TimingStats> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return TimingStats.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<TimingStats> instanceReader() {
|
||||
return TimingStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TimingStats createRandom() {
|
||||
return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()));
|
||||
}
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyperparameters> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
|
||||
return Hyperparameters.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Hyperparameters> instanceReader() {
|
||||
return Hyperparameters::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Hyperparameters createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static Hyperparameters createRandom() {
|
||||
return new Hyperparameters(
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomIntBetween(0, Integer.MAX_VALUE),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
|
||||
public class RegressionStatsTests extends AbstractBWCSerializationTestCase<RegressionStats> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionStats mutateInstanceForVersion(RegressionStats instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? RegressionStats.LENIENT_PARSER.apply(parser, null) : RegressionStats.STRICT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ToXContent.Params getToXContentParams() {
|
||||
return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<RegressionStats> instanceReader() {
|
||||
return RegressionStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static RegressionStats createRandom() {
|
||||
return new RegressionStats(
|
||||
randomAlphaOfLength(10),
|
||||
Instant.now(),
|
||||
randomIntBetween(1, Integer.MAX_VALUE),
|
||||
HyperparametersTests.createRandom(),
|
||||
TimingStatsTests.createRandom(),
|
||||
ValidationLossTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class TimingStatsTests extends AbstractBWCSerializationTestCase<TimingStats> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats mutateInstanceForVersion(TimingStats instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats doParseInstance(XContentParser parser) throws IOException {
|
||||
return TimingStats.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<TimingStats> instanceReader() {
|
||||
return TimingStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TimingStats createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TimingStats createRandom() {
|
||||
return new TimingStats(TimeValue.timeValueMillis(randomNonNegativeLong()), TimeValue.timeValueMillis(randomNonNegativeLong()));
|
||||
}
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.FoldValuesTests;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ValidationLossTests extends AbstractBWCSerializationTestCase<ValidationLoss> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss doParseInstance(XContentParser parser) throws IOException {
|
||||
return ValidationLoss.fromXContent(parser, lenient);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ValidationLoss> instanceReader() {
|
||||
return ValidationLoss::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static ValidationLoss createRandom() {
|
||||
return new ValidationLoss(
|
||||
randomAlphaOfLength(10),
|
||||
randomList(5, () -> FoldValuesTests.createRandom())
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ValidationLoss mutateInstanceForVersion(ValidationLoss instance, Version version) {
|
||||
return instance;
|
||||
}
|
||||
}
|
@ -41,7 +41,12 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.R
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||
@ -103,7 +108,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
Stats stats = buildStats(
|
||||
task.getParams().getId(),
|
||||
statsHolder.getProgressTracker().report(),
|
||||
statsHolder.getMemoryUsage()
|
||||
statsHolder.getMemoryUsage(),
|
||||
statsHolder.getAnalysisStats()
|
||||
);
|
||||
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
|
||||
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
|
||||
@ -192,7 +198,10 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
|
||||
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
|
||||
multiSearchRequest.add(buildStoredProgressSearch(configId));
|
||||
multiSearchRequest.add(buildMemoryUsageSearch(configId));
|
||||
multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE));
|
||||
multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE));
|
||||
multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE));
|
||||
multiSearchRequest.add(buildStatsDocSearch(configId, RegressionStats.TYPE_VALUE));
|
||||
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap(
|
||||
multiSearchResponse -> {
|
||||
@ -213,7 +222,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
}
|
||||
listener.onResponse(buildStats(configId,
|
||||
retrievedStatsHolder.progress.get(),
|
||||
retrievedStatsHolder.memoryUsage
|
||||
retrievedStatsHolder.memoryUsage,
|
||||
retrievedStatsHolder.analysisStats
|
||||
));
|
||||
},
|
||||
e -> listener.onFailure(ExceptionsHelper.serverError("Error searching for stats", e))
|
||||
@ -228,18 +238,17 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
return searchRequest;
|
||||
}
|
||||
|
||||
private static SearchRequest buildMemoryUsageSearch(String configId) {
|
||||
private static SearchRequest buildStatsDocSearch(String configId, String statsType) {
|
||||
SearchRequest searchRequest = new SearchRequest(MlStatsIndex.indexPattern());
|
||||
searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
|
||||
searchRequest.source().size(1);
|
||||
QueryBuilder query = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery(MemoryUsage.JOB_ID.getPreferredName(), configId))
|
||||
.filter(QueryBuilders.termQuery(MemoryUsage.TYPE.getPreferredName(), MemoryUsage.TYPE_VALUE));
|
||||
.filter(QueryBuilders.termQuery(Fields.JOB_ID.getPreferredName(), configId))
|
||||
.filter(QueryBuilders.termQuery(Fields.TYPE.getPreferredName(), statsType));
|
||||
searchRequest.source().query(query);
|
||||
searchRequest.source().sort(SortBuilders.fieldSort(MemoryUsage.TIMESTAMP.getPreferredName()).order(SortOrder.DESC)
|
||||
searchRequest.source().sort(SortBuilders.fieldSort(Fields.TIMESTAMP.getPreferredName()).order(SortOrder.DESC)
|
||||
// We need this for the search not to fail when there are no mappings yet in the index
|
||||
.unmappedType("long"));
|
||||
searchRequest.source().sort(MemoryUsage.TIMESTAMP.getPreferredName(), SortOrder.DESC);
|
||||
return searchRequest;
|
||||
}
|
||||
|
||||
@ -249,6 +258,12 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER);
|
||||
} else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) {
|
||||
retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER);
|
||||
} else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) {
|
||||
retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, OutlierDetectionStats.LENIENT_PARSER);
|
||||
} else if (hitId.startsWith(ClassificationStats.documentIdPrefix(configId))) {
|
||||
retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, ClassificationStats.LENIENT_PARSER);
|
||||
} else if (hitId.startsWith(RegressionStats.documentIdPrefix(configId))) {
|
||||
retrievedStatsHolder.analysisStats = MlParserUtils.parse(hit, RegressionStats.LENIENT_PARSER);
|
||||
} else {
|
||||
throw ExceptionsHelper.serverError("unexpected doc id [" + hitId + "]");
|
||||
}
|
||||
@ -256,7 +271,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
|
||||
private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId,
|
||||
List<PhaseProgress> progress,
|
||||
MemoryUsage memoryUsage) {
|
||||
MemoryUsage memoryUsage,
|
||||
AnalysisStats analysisStats) {
|
||||
ClusterState clusterState = clusterService.state();
|
||||
PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
|
||||
PersistentTasksCustomMetaData.PersistentTask<?> analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks);
|
||||
@ -278,6 +294,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
failureReason,
|
||||
progress,
|
||||
memoryUsage,
|
||||
analysisStats,
|
||||
node,
|
||||
assignmentExplanation
|
||||
);
|
||||
@ -287,5 +304,6 @@ public class TransportGetDataFrameAnalyticsStatsAction
|
||||
|
||||
private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report());
|
||||
private volatile MemoryUsage memoryUsage;
|
||||
private volatile AnalysisStats analysisStats;
|
||||
}
|
||||
}
|
||||
|
@ -23,6 +23,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
@ -175,6 +178,21 @@ public class AnalyticsResultProcessor {
|
||||
statsHolder.setMemoryUsage(memoryUsage);
|
||||
indexStatsResult(memoryUsage, memoryUsage::documentId);
|
||||
}
|
||||
OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats();
|
||||
if (outlierDetectionStats != null) {
|
||||
statsHolder.setAnalysisStats(outlierDetectionStats);
|
||||
indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId);
|
||||
}
|
||||
ClassificationStats classificationStats = result.getClassificationStats();
|
||||
if (classificationStats != null) {
|
||||
statsHolder.setAnalysisStats(classificationStats);
|
||||
indexStatsResult(classificationStats, classificationStats::documentId);
|
||||
}
|
||||
RegressionStats regressionStats = result.getRegressionStats();
|
||||
if (regressionStats != null) {
|
||||
statsHolder.setAnalysisStats(regressionStats);
|
||||
indexStatsResult(regressionStats, regressionStats::documentId);
|
||||
}
|
||||
}
|
||||
|
||||
private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
|
||||
|
@ -12,6 +12,9 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -28,9 +31,20 @@ public class AnalyticsResult implements ToXContentObject {
|
||||
private static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
|
||||
private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
|
||||
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
|
||||
private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
|
||||
private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
|
||||
private static final ParseField REGRESSION_STATS = new ParseField("regression_stats");
|
||||
|
||||
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
|
||||
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition.Builder) a[2], (MemoryUsage) a[3]));
|
||||
a -> new AnalyticsResult(
|
||||
(RowResults) a[0],
|
||||
(Integer) a[1],
|
||||
(TrainedModelDefinition.Builder) a[2],
|
||||
(MemoryUsage) a[3],
|
||||
(OutlierDetectionStats) a[4],
|
||||
(ClassificationStats) a[5],
|
||||
(RegressionStats) a[6]
|
||||
));
|
||||
|
||||
static {
|
||||
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
|
||||
@ -38,6 +52,9 @@ public class AnalyticsResult implements ToXContentObject {
|
||||
// TODO change back to STRICT_PARSER once native side is aligned
|
||||
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
|
||||
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE);
|
||||
PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
|
||||
PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
|
||||
PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
|
||||
}
|
||||
|
||||
private final RowResults rowResults;
|
||||
@ -45,16 +62,25 @@ public class AnalyticsResult implements ToXContentObject {
|
||||
private final TrainedModelDefinition.Builder inferenceModelBuilder;
|
||||
private final TrainedModelDefinition inferenceModel;
|
||||
private final MemoryUsage memoryUsage;
|
||||
private final OutlierDetectionStats outlierDetectionStats;
|
||||
private final ClassificationStats classificationStats;
|
||||
private final RegressionStats regressionStats;
|
||||
|
||||
public AnalyticsResult(@Nullable RowResults rowResults,
|
||||
@Nullable Integer progressPercent,
|
||||
@Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
|
||||
@Nullable MemoryUsage memoryUsage) {
|
||||
@Nullable MemoryUsage memoryUsage,
|
||||
@Nullable OutlierDetectionStats outlierDetectionStats,
|
||||
@Nullable ClassificationStats classificationStats,
|
||||
@Nullable RegressionStats regressionStats) {
|
||||
this.rowResults = rowResults;
|
||||
this.progressPercent = progressPercent;
|
||||
this.inferenceModelBuilder = inferenceModelBuilder;
|
||||
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
|
||||
this.memoryUsage = memoryUsage;
|
||||
this.outlierDetectionStats = outlierDetectionStats;
|
||||
this.classificationStats = classificationStats;
|
||||
this.regressionStats = regressionStats;
|
||||
}
|
||||
|
||||
public RowResults getRowResults() {
|
||||
@ -73,6 +99,18 @@ public class AnalyticsResult implements ToXContentObject {
|
||||
return memoryUsage;
|
||||
}
|
||||
|
||||
public OutlierDetectionStats getOutlierDetectionStats() {
|
||||
return outlierDetectionStats;
|
||||
}
|
||||
|
||||
public ClassificationStats getClassificationStats() {
|
||||
return classificationStats;
|
||||
}
|
||||
|
||||
public RegressionStats getRegressionStats() {
|
||||
return regressionStats;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
@ -90,6 +128,15 @@ public class AnalyticsResult implements ToXContentObject {
|
||||
if (memoryUsage != null) {
|
||||
builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params);
|
||||
}
|
||||
if (outlierDetectionStats != null) {
|
||||
builder.field(OUTLIER_DETECTION_STATS.getPreferredName(), outlierDetectionStats, params);
|
||||
}
|
||||
if (classificationStats != null) {
|
||||
builder.field(CLASSIFICATION_STATS.getPreferredName(), classificationStats, params);
|
||||
}
|
||||
if (regressionStats != null) {
|
||||
builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
@ -107,11 +154,15 @@ public class AnalyticsResult implements ToXContentObject {
|
||||
return Objects.equals(rowResults, that.rowResults)
|
||||
&& Objects.equals(progressPercent, that.progressPercent)
|
||||
&& Objects.equals(inferenceModel, that.inferenceModel)
|
||||
&& Objects.equals(memoryUsage, that.memoryUsage);
|
||||
&& Objects.equals(memoryUsage, that.memoryUsage)
|
||||
&& Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
|
||||
&& Objects.equals(classificationStats, that.classificationStats)
|
||||
&& Objects.equals(regressionStats, that.regressionStats);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage);
|
||||
return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
|
||||
regressionStats);
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.stats;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
@ -17,10 +18,12 @@ public class StatsHolder {
|
||||
|
||||
private final ProgressTracker progressTracker;
|
||||
private final AtomicReference<MemoryUsage> memoryUsageHolder;
|
||||
private final AtomicReference<AnalysisStats> analysisStatsHolder;
|
||||
|
||||
public StatsHolder() {
|
||||
progressTracker = new ProgressTracker();
|
||||
memoryUsageHolder = new AtomicReference<>();
|
||||
analysisStatsHolder = new AtomicReference<>();
|
||||
}
|
||||
|
||||
public ProgressTracker getProgressTracker() {
|
||||
@ -34,4 +37,12 @@ public class StatsHolder {
|
||||
public MemoryUsage getMemoryUsage() {
|
||||
return memoryUsageHolder.get();
|
||||
}
|
||||
|
||||
public void setAnalysisStats(AnalysisStats analysisStats) {
|
||||
analysisStatsHolder.set(analysisStats);
|
||||
}
|
||||
|
||||
public AnalysisStats getAnalysisStats() {
|
||||
return analysisStatsHolder.get();
|
||||
}
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
|
||||
private static final String CONFIG_ID = "config-id";
|
||||
private static final int NUM_ROWS = 100;
|
||||
private static final int NUM_COLS = 4;
|
||||
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null);
|
||||
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null);
|
||||
|
||||
private Client client;
|
||||
private DataFrameAnalyticsAuditor auditor;
|
||||
|
@ -100,7 +100,9 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||
|
||||
public void testProcess_GivenEmptyResults() {
|
||||
givenDataFrameRows(2);
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50, null, null), new AnalyticsResult(null, 100, null, null)));
|
||||
givenProcessResults(Arrays.asList(
|
||||
new AnalyticsResult(null, 50, null, null, null, null, null),
|
||||
new AnalyticsResult(null, 100, null, null, null, null, null)));
|
||||
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
||||
|
||||
resultProcessor.process(process);
|
||||
@ -115,8 +117,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||
givenDataFrameRows(2);
|
||||
RowResults rowResults1 = mock(RowResults.class);
|
||||
RowResults rowResults2 = mock(RowResults.class);
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null),
|
||||
new AnalyticsResult(rowResults2, 100, null, null)));
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
|
||||
new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
|
||||
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
||||
|
||||
resultProcessor.process(process);
|
||||
@ -133,8 +135,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||
givenDataFrameRows(2);
|
||||
RowResults rowResults1 = mock(RowResults.class);
|
||||
RowResults rowResults2 = mock(RowResults.class);
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null),
|
||||
new AnalyticsResult(rowResults2, 100, null, null)));
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
|
||||
new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
|
||||
|
||||
doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
|
||||
|
||||
@ -167,7 +169,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||
extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
|
||||
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
|
||||
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null)));
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
|
||||
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
|
||||
|
||||
resultProcessor.process(process);
|
||||
@ -212,7 +214,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
||||
|
||||
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null)));
|
||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
|
||||
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
||||
|
||||
resultProcessor.process(process);
|
||||
|
@ -11,7 +11,14 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
@ -36,6 +43,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
||||
RowResults rowResults = null;
|
||||
Integer progressPercent = null;
|
||||
TrainedModelDefinition.Builder inferenceModel = null;
|
||||
MemoryUsage memoryUsage = null;
|
||||
OutlierDetectionStats outlierDetectionStats = null;
|
||||
ClassificationStats classificationStats = null;
|
||||
RegressionStats regressionStats = null;
|
||||
if (randomBoolean()) {
|
||||
rowResults = RowResultsTests.createRandom();
|
||||
}
|
||||
@ -45,7 +56,20 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
||||
if (randomBoolean()) {
|
||||
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
|
||||
}
|
||||
return new AnalyticsResult(rowResults, progressPercent, inferenceModel, MemoryUsageTests.createRandom());
|
||||
if (randomBoolean()) {
|
||||
memoryUsage = MemoryUsageTests.createRandom();
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
outlierDetectionStats = OutlierDetectionStatsTests.createRandom();
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
classificationStats = ClassificationStatsTests.createRandom();
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
regressionStats = RegressionStatsTests.createRandom();
|
||||
}
|
||||
return new AnalyticsResult(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
|
||||
regressionStats);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
Loading…
x
Reference in New Issue
Block a user