This commit instruments data frame analytics with stats about the data being analyzed. In particular, we count training docs, test docs, and skipped docs. In order to count docs with missing values as skipped docs for analyses that do not support missing values, this commit changes the extractor so that it only ignores docs with missing values when it collects the data summary, which is used to estimate memory usage. Backport of #53998
parent 7dcacf531f
commit 5ce7c99e74
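For readers of this backport, a brief, hypothetical usage sketch (not part of the commit itself) showing how the counts carried by the new high-level REST client DataCounts class added below render to JSON. Strings.toString is an existing Elasticsearch utility; the expected output simply follows the toXContent implementation introduced in this diff.

    import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts;
    import org.elasticsearch.common.Strings;

    public class DataCountsExample {
        public static void main(String[] args) {
            // A job that trained on 300 docs, held none out for testing,
            // and skipped 2 docs because of missing values.
            DataCounts counts = new DataCounts(300L, 0L, 2L);

            // Expected output, per the toXContent implementation below:
            // {"training_docs_count":300,"test_docs_count":0,"skipped_docs_count":2}
            System.out.println(Strings.toString(counts));
        }
    }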
File: org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats

@@ -21,6 +21,7 @@ package org.elasticsearch.client.ml.dataframe;
 
 import org.elasticsearch.client.ml.NodeAttributes;
 import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
+import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
@@ -47,6 +48,7 @@ public class DataFrameAnalyticsStats {
     static final ParseField STATE = new ParseField("state");
     static final ParseField FAILURE_REASON = new ParseField("failure_reason");
     static final ParseField PROGRESS = new ParseField("progress");
+    static final ParseField DATA_COUNTS = new ParseField("data_counts");
     static final ParseField MEMORY_USAGE = new ParseField("memory_usage");
     static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats");
     static final ParseField NODE = new ParseField("node");
@@ -60,10 +62,11 @@ public class DataFrameAnalyticsStats {
             (DataFrameAnalyticsState) args[1],
             (String) args[2],
             (List<PhaseProgress>) args[3],
-            (MemoryUsage) args[4],
-            (AnalysisStats) args[5],
-            (NodeAttributes) args[6],
-            (String) args[7]));
+            (DataCounts) args[4],
+            (MemoryUsage) args[5],
+            (AnalysisStats) args[6],
+            (NodeAttributes) args[7],
+            (String) args[8]));
 
     static {
         PARSER.declareString(constructorArg(), ID);
@@ -75,6 +78,7 @@ public class DataFrameAnalyticsStats {
         }, STATE, ObjectParser.ValueType.STRING);
         PARSER.declareString(optionalConstructorArg(), FAILURE_REASON);
         PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS);
+        PARSER.declareObject(optionalConstructorArg(), DataCounts.PARSER, DATA_COUNTS);
         PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE);
         PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS);
         PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE);
@@ -93,19 +97,21 @@ public class DataFrameAnalyticsStats {
     private final DataFrameAnalyticsState state;
     private final String failureReason;
     private final List<PhaseProgress> progress;
+    private final DataCounts dataCounts;
     private final MemoryUsage memoryUsage;
     private final AnalysisStats analysisStats;
     private final NodeAttributes node;
     private final String assignmentExplanation;
 
     public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason,
-                                   @Nullable List<PhaseProgress> progress, @Nullable MemoryUsage memoryUsage,
-                                   @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node,
+                                   @Nullable List<PhaseProgress> progress, @Nullable DataCounts dataCounts,
+                                   @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node,
                                    @Nullable String assignmentExplanation) {
         this.id = id;
         this.state = state;
         this.failureReason = failureReason;
         this.progress = progress;
+        this.dataCounts = dataCounts;
         this.memoryUsage = memoryUsage;
         this.analysisStats = analysisStats;
         this.node = node;
@@ -128,6 +134,11 @@ public class DataFrameAnalyticsStats {
         return progress;
     }
 
+    @Nullable
+    public DataCounts getDataCounts() {
+        return dataCounts;
+    }
+
     @Nullable
     public MemoryUsage getMemoryUsage() {
         return memoryUsage;
@@ -156,6 +167,7 @@ public class DataFrameAnalyticsStats {
             && Objects.equals(state, other.state)
             && Objects.equals(failureReason, other.failureReason)
             && Objects.equals(progress, other.progress)
+            && Objects.equals(dataCounts, other.dataCounts)
             && Objects.equals(memoryUsage, other.memoryUsage)
             && Objects.equals(analysisStats, other.analysisStats)
             && Objects.equals(node, other.node)
@@ -164,7 +176,7 @@ public class DataFrameAnalyticsStats {
 
     @Override
     public int hashCode() {
-        return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
+        return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, assignmentExplanation);
     }
 
     @Override
@@ -174,6 +186,7 @@ public class DataFrameAnalyticsStats {
             .add("state", state)
             .add("failureReason", failureReason)
             .add("progress", progress)
+            .add("dataCounts", dataCounts)
             .add("memoryUsage", memoryUsage)
             .add("analysisStats", analysisStats)
             .add("node", node)
File (new): org.elasticsearch.client.ml.dataframe.stats.common.DataCounts

@@ -0,0 +1,119 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.client.ml.dataframe.stats.common;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.inject.internal.ToStringBuilder;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class DataCounts implements ToXContentObject {

    public static final String TYPE_VALUE = "analytics_data_counts";

    public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count");
    public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count");
    public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count");

    public static final ConstructingObjectParser<DataCounts, Void> PARSER = new ConstructingObjectParser<>(TYPE_VALUE, true,
        a -> {
            Long trainingDocsCount = (Long) a[0];
            Long testDocsCount = (Long) a[1];
            Long skippedDocsCount = (Long) a[2];
            return new DataCounts(
                getOrDefault(trainingDocsCount, 0L),
                getOrDefault(testDocsCount, 0L),
                getOrDefault(skippedDocsCount, 0L)
            );
        });

    static {
        PARSER.declareLong(optionalConstructorArg(), TRAINING_DOCS_COUNT);
        PARSER.declareLong(optionalConstructorArg(), TEST_DOCS_COUNT);
        PARSER.declareLong(optionalConstructorArg(), SKIPPED_DOCS_COUNT);
    }

    private final long trainingDocsCount;
    private final long testDocsCount;
    private final long skippedDocsCount;

    public DataCounts(long trainingDocsCount, long testDocsCount, long skippedDocsCount) {
        this.trainingDocsCount = trainingDocsCount;
        this.testDocsCount = testDocsCount;
        this.skippedDocsCount = skippedDocsCount;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount);
        builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount);
        builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        DataCounts that = (DataCounts) o;
        return trainingDocsCount == that.trainingDocsCount
            && testDocsCount == that.testDocsCount
            && skippedDocsCount == that.skippedDocsCount;
    }

    @Override
    public int hashCode() {
        return Objects.hash(trainingDocsCount, testDocsCount, skippedDocsCount);
    }

    @Override
    public String toString() {
        return new ToStringBuilder(getClass())
            .add(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount)
            .add(TEST_DOCS_COUNT.getPreferredName(), testDocsCount)
            .add(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount)
            .toString();
    }

    public long getTrainingDocsCount() {
        return trainingDocsCount;
    }

    public long getTestDocsCount() {
        return testDocsCount;
    }

    public long getSkippedDocsCount() {
        return skippedDocsCount;
    }

    private static <T> T getOrDefault(@Nullable T value, T defaultValue) {
        return value != null ? value : defaultValue;
    }
}
File: DataFrameAnalyticsStatsTests (client)

@@ -23,6 +23,7 @@ import org.elasticsearch.client.ml.NodeAttributesTests;
 import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
 import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider;
 import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests;
+import org.elasticsearch.client.ml.dataframe.stats.common.DataCountsTests;
 import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests;
 import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
 import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests;
@@ -68,6 +69,7 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase {
             randomFrom(DataFrameAnalyticsState.values()),
             randomBoolean() ? null : randomAlphaOfLength(10),
             randomBoolean() ? null : createRandomProgress(),
+            randomBoolean() ? null : DataCountsTests.createRandom(),
             randomBoolean() ? null : MemoryUsageTests.createRandom(),
             analysisStats,
             randomBoolean() ? null : NodeAttributesTests.createRandom(),
@@ -93,6 +95,9 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase {
         if (stats.getProgress() != null) {
             builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress());
         }
+        if (stats.getDataCounts() != null) {
+            builder.field(DataFrameAnalyticsStats.DATA_COUNTS.getPreferredName(), stats.getDataCounts());
+        }
         if (stats.getMemoryUsage() != null) {
             builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage());
         }
File (new): org.elasticsearch.client.ml.dataframe.stats.common.DataCountsTests

@@ -0,0 +1,51 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.client.ml.dataframe.stats.common;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;

public class DataCountsTests extends AbstractXContentTestCase<DataCounts> {

    @Override
    protected DataCounts createTestInstance() {
        return createRandom();
    }

    public static DataCounts createRandom() {
        return new DataCounts(
            randomNonNegativeLong(),
            randomNonNegativeLong(),
            randomNonNegativeLong()
        );
    }

    @Override
    protected DataCounts doParseInstance(XContentParser parser) throws IOException {
        return DataCounts.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }
}
File: GetDataFrameAnalyticsStatsAction (x-pack core)

@@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
@@ -165,6 +166,9 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
         */
        private final List<PhaseProgress> progress;

+       @Nullable
+       private final DataCounts dataCounts;
+
        @Nullable
        private final MemoryUsage memoryUsage;

@@ -177,12 +181,13 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
        private final String assignmentExplanation;

        public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List<PhaseProgress> progress,
-                    @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node,
-                    @Nullable String assignmentExplanation) {
+                    @Nullable DataCounts dataCounts, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats,
+                    @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) {
            this.id = Objects.requireNonNull(id);
            this.state = Objects.requireNonNull(state);
            this.failureReason = failureReason;
            this.progress = Objects.requireNonNull(progress);
+           this.dataCounts = dataCounts;
            this.memoryUsage = memoryUsage;
            this.analysisStats = analysisStats;
            this.node = node;
@@ -198,6 +203,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
            } else {
                progress = in.readList(PhaseProgress::new);
            }
+           if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+               dataCounts = in.readOptionalWriteable(DataCounts::new);
+           } else {
+               dataCounts = null;
+           }
            if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
                memoryUsage = in.readOptionalWriteable(MemoryUsage::new);
            } else {
@@ -261,6 +271,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
            return progress;
        }

+       @Nullable
+       public DataCounts getDataCounts() {
+           return dataCounts;
+       }
+
        @Nullable
        public MemoryUsage getMemoryUsage() {
            return memoryUsage;
@@ -293,6 +308,9 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
            if (progress != null) {
                builder.field("progress", progress);
            }
+           if (dataCounts != null) {
+               builder.field("data_counts", dataCounts);
+           }
            if (memoryUsage != null) {
                builder.field("memory_usage", memoryUsage);
            }
@@ -331,6 +349,9 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
            } else {
                out.writeList(progress);
            }
+           if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+               out.writeOptionalWriteable(dataCounts);
+           }
            if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
                out.writeOptionalWriteable(memoryUsage);
            }
@@ -369,7 +390,8 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna

        @Override
        public int hashCode() {
-           return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
+           return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node,
+               assignmentExplanation);
        }

        @Override
@@ -385,6 +407,7 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
                && Objects.equals(this.state, other.state)
                && Objects.equals(this.failureReason, other.failureReason)
                && Objects.equals(this.progress, other.progress)
+               && Objects.equals(this.dataCounts, other.dataCounts)
                && Objects.equals(this.memoryUsage, other.memoryUsage)
                && Objects.equals(this.analysisStats, other.analysisStats)
                && Objects.equals(this.node, other.node)
File (new): org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts

@@ -0,0 +1,120 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.util.Objects;

public class DataCounts implements ToXContentObject, Writeable {

    public static final String TYPE_VALUE = "analytics_data_counts";

    public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count");
    public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count");
    public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count");

    public static final ConstructingObjectParser<DataCounts, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<DataCounts, Void> LENIENT_PARSER = createParser(true);

    private static ConstructingObjectParser<DataCounts, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<DataCounts, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields,
            a -> new DataCounts((String) a[0], (long) a[1], (long) a[2], (long) a[3]));

        parser.declareString((bucket, s) -> {}, Fields.TYPE);
        parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
        parser.declareLong(ConstructingObjectParser.constructorArg(), TRAINING_DOCS_COUNT);
        parser.declareLong(ConstructingObjectParser.constructorArg(), TEST_DOCS_COUNT);
        parser.declareLong(ConstructingObjectParser.constructorArg(), SKIPPED_DOCS_COUNT);
        return parser;
    }

    private final String jobId;
    private final long trainingDocsCount;
    private final long testDocsCount;
    private final long skippedDocsCount;

    public DataCounts(String jobId, long trainingDocsCount, long testDocsCount, long skippedDocsCount) {
        this.jobId = Objects.requireNonNull(jobId);
        this.trainingDocsCount = trainingDocsCount;
        this.testDocsCount = testDocsCount;
        this.skippedDocsCount = skippedDocsCount;
    }

    public DataCounts(StreamInput in) throws IOException {
        this.jobId = in.readString();
        this.trainingDocsCount = in.readVLong();
        this.testDocsCount = in.readVLong();
        this.skippedDocsCount = in.readVLong();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(jobId);
        out.writeVLong(trainingDocsCount);
        out.writeVLong(testDocsCount);
        out.writeVLong(skippedDocsCount);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
            builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE);
            builder.field(Fields.JOB_ID.getPreferredName(), jobId);
        }
        builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount);
        builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount);
        builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        DataCounts that = (DataCounts) o;
        return Objects.equals(jobId, that.jobId)
            && trainingDocsCount == that.trainingDocsCount
            && testDocsCount == that.testDocsCount
            && skippedDocsCount == that.skippedDocsCount;
    }

    @Override
    public int hashCode() {
        return Objects.hash(jobId, trainingDocsCount, testDocsCount, skippedDocsCount);
    }

    public static String documentId(String jobId) {
        return TYPE_VALUE + "_" + jobId;
    }

    public String getJobId() {
        return jobId;
    }

    public long getTrainingDocsCount() {
        return trainingDocsCount;
    }

    public long getTestDocsCount() {
        return testDocsCount;
    }

    public long getSkippedDocsCount() {
        return skippedDocsCount;
    }
}
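As an aside, a minimal sketch (not part of the commit) of how this server-side DataCounts is persisted and looked up as a stats document: documentId derives the doc id from the job id, and the FOR_INTERNAL_STORAGE param additionally writes the type and job_id fields that the stats index mappings below declare. The job id "my_job" is made up for illustration.

    import java.util.Collections;

    import org.elasticsearch.common.Strings;
    import org.elasticsearch.common.xcontent.ToXContent;
    import org.elasticsearch.common.xcontent.XContentBuilder;
    import org.elasticsearch.common.xcontent.XContentFactory;
    import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
    import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

    public class DataCountsStorageSketch {
        public static void main(String[] args) throws Exception {
            DataCounts counts = new DataCounts("my_job", 300L, 0L, 2L);

            // Doc id the transport action matches when it fetches the stats documents.
            System.out.println(DataCounts.documentId("my_job")); // analytics_data_counts_my_job

            // Internal-storage rendering adds the type and job_id fields on top of the three counts.
            ToXContent.Params params = new ToXContent.MapParams(
                Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
            XContentBuilder builder = counts.toXContent(XContentFactory.jsonBuilder(), params);
            System.out.println(Strings.toString(builder));
        }
    }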
File: ML stats index mappings (JSON)

@@ -85,6 +85,9 @@
        "peak_usage_bytes" : {
          "type" : "long"
        },
+       "skipped_docs_count": {
+         "type": "long"
+       },
        "timestamp" : {
          "type" : "date"
        },
@@ -98,6 +101,12 @@
            }
          }
        },
+       "test_docs_count": {
+         "type": "long"
+       },
+       "training_docs_count": {
+         "type": "long"
+       },
        "type" : {
          "type" : "keyword"
        },
File: GetDataFrameAnalyticsStatsActionResponseTests

@@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCountsTests;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
@@ -42,6 +44,7 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
            List<PhaseProgress> progress = new ArrayList<>(progressSize);
            IntStream.of(progressSize).forEach(progressIndex -> progress.add(
                new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))));
+           DataCounts dataCounts = randomBoolean() ? null : DataCountsTests.createRandom();
            MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom();
            AnalysisStats analysisStats = randomBoolean() ? null :
                randomFrom(
@@ -50,7 +53,7 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
                    RegressionStatsTests.createRandom()
                );
            Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(),
-               randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null,
+               randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, dataCounts, memoryUsage, analysisStats, null,
                randomAlphaOfLength(20));
            analytics.add(stats);
        }
File (new): org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCountsTests

@@ -0,0 +1,66 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.dataframe.stats.common;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before;

import java.io.IOException;
import java.util.Collections;

public class DataCountsTests extends AbstractBWCSerializationTestCase<DataCounts> {

    private boolean lenient;

    @Before
    public void chooseLenient() {
        lenient = randomBoolean();
    }

    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    @Override
    protected DataCounts mutateInstanceForVersion(DataCounts instance, Version version) {
        return instance;
    }

    @Override
    protected DataCounts doParseInstance(XContentParser parser) throws IOException {
        return lenient ? DataCounts.LENIENT_PARSER.apply(parser, null) : DataCounts.STRICT_PARSER.apply(parser, null);
    }

    @Override
    protected ToXContent.Params getToXContentParams() {
        return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
    }

    @Override
    protected Writeable.Reader<DataCounts> instanceReader() {
        return DataCounts::new;
    }

    @Override
    protected DataCounts createTestInstance() {
        return createRandom();
    }

    public static DataCounts createRandom() {
        return new DataCounts(
            randomAlphaOfLength(10),
            randomNonNegativeLong(),
            randomNonNegativeLong(),
            randomNonNegativeLong()
        );
    }
}
File: ClassificationIT

@@ -22,6 +22,7 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
+import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
@@ -49,6 +50,7 @@ import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.in;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.startsWith;
 
@@ -158,6 +160,12 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
            assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
        }

+       GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId);
+       assertThat(stats.getDataCounts().getJobId(), equalTo(jobId));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(300L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L));
+       assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
+
        assertProgress(jobId, 100, 100, 100, 100);
        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
        assertModelStatePersisted(stateDocId());
@@ -225,6 +233,14 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
        assertThat(trainingRowsCount, greaterThan(0));
        assertThat(nonTrainingRowsCount, greaterThan(0));

+       GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId);
+       assertThat(stats.getDataCounts().getJobId(), equalTo(jobId));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(300L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(300L));
+       assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
+
        assertProgress(jobId, 100, 100, 100, 100);
        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
        assertModelStatePersisted(stateDocId());
File: OutlierDetectionWithMissingFieldsIT

@@ -12,6 +12,7 @@ import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
 import org.junit.After;
@@ -79,6 +80,12 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
        startAnalytics(id);
        waitUntilAnalyticsIsStopped(id);

+       GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id);
+       assertThat(stats.getDataCounts().getJobId(), equalTo(id));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L));
+       assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(2L));
+
        SearchResponse sourceData = client().prepareSearch(sourceIndex).get();
        for (SearchHit hit : sourceData.getHits()) {
            GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
File: RegressionIT

@@ -15,6 +15,7 @@ import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
@@ -33,6 +34,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThan;
 
 public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
@@ -143,6 +145,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {

        assertProgress(jobId, 100, 100, 100, 100);
        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
+
+       GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId);
+       assertThat(stats.getDataCounts().getJobId(), equalTo(jobId));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(350L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L));
+       assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
+
        assertModelStatePersisted(stateDocId());
        assertInferenceModelPersisted(jobId);
        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
@@ -199,6 +208,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
        assertThat(trainingRowsCount, greaterThan(0));
        assertThat(nonTrainingRowsCount, greaterThan(0));

+       GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId);
+       assertThat(stats.getDataCounts().getJobId(), equalTo(jobId));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(350L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(350L));
+       assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
+
        assertProgress(jobId, 100, 100, 100, 100);
        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
        assertModelStatePersisted(stateDocId());
File: RunDataFrameAnalyticsIT

@@ -85,6 +85,11 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest

        startAnalytics(id);
        waitUntilAnalyticsIsStopped(id);
+       GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id);
+       assertThat(stats.getDataCounts().getJobId(), equalTo(id));
+       assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L));
+       assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L));
+       assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));

        SearchResponse sourceData = client().prepareSearch(sourceIndex).get();
        double scoreOfOutlier = 0.0;
File: TransportGetDataFrameAnalyticsStatsAction

@@ -42,6 +42,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
@@ -108,6 +109,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
        Stats stats = buildStats(
            task.getParams().getId(),
            statsHolder.getProgressTracker().report(),
+           statsHolder.getDataCountsTracker().report(task.getParams().getId()),
            statsHolder.getMemoryUsage(),
            statsHolder.getAnalysisStats()
        );
@@ -198,6 +200,7 @@ public class TransportGetDataFrameAnalyticsStatsAction

        MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
        multiSearchRequest.add(buildStoredProgressSearch(configId));
+       multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE));
        multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE));
        multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE));
        multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE));
@@ -222,6 +225,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
        }
        listener.onResponse(buildStats(configId,
            retrievedStatsHolder.progress.get(),
+           retrievedStatsHolder.dataCounts,
            retrievedStatsHolder.memoryUsage,
            retrievedStatsHolder.analysisStats
        ));
@@ -256,6 +260,8 @@ public class TransportGetDataFrameAnalyticsStatsAction
        String hitId = hit.getId();
        if (StoredProgress.documentId(configId).equals(hitId)) {
            retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER);
+       } else if (DataCounts.documentId(configId).equals(hitId)) {
+           retrievedStatsHolder.dataCounts = MlParserUtils.parse(hit, DataCounts.LENIENT_PARSER);
        } else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) {
            retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER);
        } else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) {
@@ -271,6 +277,7 @@ public class TransportGetDataFrameAnalyticsStatsAction

    private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId,
                                                                       List<PhaseProgress> progress,
+                                                                      DataCounts dataCounts,
                                                                       MemoryUsage memoryUsage,
                                                                       AnalysisStats analysisStats) {
        ClusterState clusterState = clusterService.state();
@@ -293,6 +300,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
            analyticsState,
            failureReason,
            progress,
+           dataCounts,
            memoryUsage,
            analysisStats,
            node,
@@ -303,6 +311,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
    private static class RetrievedStatsHolder {

        private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report());
+       private volatile DataCounts dataCounts;
        private volatile MemoryUsage memoryUsage;
        private volatile AnalysisStats analysisStats;
    }
@@ -19,6 +19,9 @@ import org.elasticsearch.action.search.SearchScrollRequestBuilder;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.sort.SortOrder;

@@ -187,7 +190,7 @@ public class DataFrameDataExtractor
             if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) {
                 extractedValues[i] = Objects.toString(values[0]);
             } else {
-                if (values.length == 0 && context.includeRowsWithMissingValues) {
+                if (values.length == 0 && context.supportsRowsWithMissingValues) {
                     // if values is empty then it means it's a missing value
                     extractedValues[i] = NULL_VALUE;
                 } else {

@@ -263,13 +266,29 @@ public class DataFrameDataExtractor
     }

     private SearchRequestBuilder buildDataSummarySearchRequestBuilder() {
+
+        QueryBuilder summaryQuery = context.query;
+        if (context.supportsRowsWithMissingValues == false) {
+            summaryQuery = QueryBuilders.boolQuery()
+                .filter(summaryQuery)
+                .filter(allExtractedFieldsExistQuery());
+        }
+
         return new SearchRequestBuilder(client, SearchAction.INSTANCE)
             .setIndices(context.indices)
             .setSize(0)
-            .setQuery(context.query)
+            .setQuery(summaryQuery)
             .setTrackTotalHits(true);
     }

+    private QueryBuilder allExtractedFieldsExistQuery() {
+        BoolQueryBuilder query = QueryBuilders.boolQuery();
+        for (ExtractedField field : context.extractedFields.getAllFields()) {
+            query.filter(QueryBuilders.existsQuery(field.getName()));
+        }
+        return query;
+    }
+
     public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
         return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis);
     }
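For reference, a minimal standalone sketch (not part of the commit) of the summary-query wrapping shown above. The class name SummaryQuerySketch, the helper summaryQuery, and the field names field_1/field_2 are hypothetical; the real code builds the same shape inside buildDataSummarySearchRequestBuilder from context.extractedFields.getAllFields().

import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;

import java.util.Arrays;
import java.util.List;

public class SummaryQuerySketch {

    // Wrap the user query so that docs missing any analyzed field are excluded
    // from the data summary when the analysis does not support missing values.
    static QueryBuilder summaryQuery(QueryBuilder userQuery, List<String> analyzedFields, boolean supportsMissing) {
        if (supportsMissing) {
            return userQuery;   // analyses that handle missing values count every matching doc
        }
        BoolQueryBuilder allFieldsExist = QueryBuilders.boolQuery();
        for (String field : analyzedFields) {
            allFieldsExist.filter(QueryBuilders.existsQuery(field));
        }
        return QueryBuilders.boolQuery()
            .filter(userQuery)
            .filter(allFieldsExist);
    }

    public static void main(String[] args) {
        QueryBuilder query = summaryQuery(QueryBuilders.matchAllQuery(), Arrays.asList("field_1", "field_2"), false);
        // Prints the bool/filter/exists JSON shape that the extractor tests below assert on.
        System.out.println(query);
    }
}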
@@ -21,10 +21,10 @@ public class DataFrameDataExtractorContext
     final int scrollSize;
     final Map<String, String> headers;
     final boolean includeSource;
-    final boolean includeRowsWithMissingValues;
+    final boolean supportsRowsWithMissingValues;

     DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List<String> indices, QueryBuilder query, int scrollSize,
-                                  Map<String, String> headers, boolean includeSource, boolean includeRowsWithMissingValues) {
+                                  Map<String, String> headers, boolean includeSource, boolean supportsRowsWithMissingValues) {
         this.jobId = Objects.requireNonNull(jobId);
         this.extractedFields = Objects.requireNonNull(extractedFields);
         this.indices = indices.toArray(new String[indices.size()]);

@@ -32,6 +32,6 @@ public class DataFrameDataExtractorContext
         this.scrollSize = scrollSize;
         this.headers = headers;
         this.includeSource = includeSource;
-        this.includeRowsWithMissingValues = includeRowsWithMissingValues;
+        this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
     }
 }
@@ -7,11 +7,9 @@ package org.elasticsearch.xpack.ml.dataframe.extractor;

 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
-import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;

 import java.util.Arrays;

@@ -28,18 +26,18 @@ public class DataFrameDataExtractorFactory
     private final QueryBuilder sourceQuery;
     private final ExtractedFields extractedFields;
     private final Map<String, String> headers;
-    private final boolean includeRowsWithMissingValues;
+    private final boolean supportsRowsWithMissingValues;

     private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, QueryBuilder sourceQuery,
                                           ExtractedFields extractedFields, Map<String, String> headers,
-                                          boolean includeRowsWithMissingValues) {
+                                          boolean supportsRowsWithMissingValues) {
         this.client = Objects.requireNonNull(client);
         this.analyticsId = Objects.requireNonNull(analyticsId);
         this.indices = Objects.requireNonNull(indices);
         this.sourceQuery = Objects.requireNonNull(sourceQuery);
         this.extractedFields = Objects.requireNonNull(extractedFields);
         this.headers = headers;
-        this.includeRowsWithMissingValues = includeRowsWithMissingValues;
+        this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
     }

     public DataFrameDataExtractor newExtractor(boolean includeSource) {

@@ -47,11 +45,11 @@ public class DataFrameDataExtractorFactory
             analyticsId,
             extractedFields,
             indices,
-            createQuery(),
+            QueryBuilders.boolQuery().filter(sourceQuery),
             1000,
             headers,
             includeSource,
-            includeRowsWithMissingValues
+            supportsRowsWithMissingValues
         );
         return new DataFrameDataExtractor(client, context);
     }

@@ -60,23 +58,6 @@ public class DataFrameDataExtractorFactory
         return extractedFields;
     }
-
-    private QueryBuilder createQuery() {
-        BoolQueryBuilder query = QueryBuilders.boolQuery();
-        query.filter(sourceQuery);
-        if (includeRowsWithMissingValues == false) {
-            query.filter(allExtractedFieldsExistQuery());
-        }
-        return query;
-    }
-
-    private QueryBuilder allExtractedFieldsExistQuery() {
-        BoolQueryBuilder query = QueryBuilders.boolQuery();
-        for (ExtractedField field : extractedFields.getAllFields()) {
-            query.filter(QueryBuilders.existsQuery(field.getName()));
-        }
-        return query;
-    }

     /**
      * Create a new extractor factory
     *

@@ -109,6 +90,7 @@ public class DataFrameDataExtractorFactory
         extractedFieldsDetectorFactory.createFromDest(config, ActionListener.wrap(
             extractedFieldsDetector -> {
                 ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();

                 DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(),
                     Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields,
                     config.getHeaders(), config.getAnalysis().supportsMissingValues());
@@ -11,6 +11,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;

@@ -22,8 +23,10 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

@@ -34,7 +37,9 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact
 import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
 import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
+import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
+import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

@@ -156,7 +161,10 @@ public class AnalyticsProcessManager
         AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
         try {
             writeHeaderRecord(dataExtractor, process);
-            writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker());
+            writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(),
+                task.getStatsHolder().getDataCountsTracker());
+            processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
+                DataCounts::documentId);
             process.writeEndOfDataMessage();
             process.flushStream();

@@ -205,8 +213,8 @@ public class AnalyticsProcessManager
         }
     }

-    private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
-                               DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {
+    private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis,
+                               ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException {

         CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
             .create(analysis);

@@ -223,11 +231,14 @@ public class AnalyticsProcessManager
             Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
             if (rows.isPresent()) {
                 for (DataFrameDataExtractor.Row row : rows.get()) {
-                    if (row.shouldSkip() == false) {
+                    if (row.shouldSkip()) {
+                        dataCountsTracker.incrementSkippedDocsCount();
+                    } else {
                         String[] rowValues = row.getValues();
                         System.arraycopy(rowValues, 0, record, 0, rowValues.length);
                         record[record.length - 2] = String.valueOf(row.getChecksum());
-                        crossValidationSplitter.process(record);
+                        crossValidationSplitter.process(record, dataCountsTracker::incrementTrainingDocsCount,
+                            dataCountsTracker::incrementTestDocsCount);
                         process.writeRecord(record);
                     }
                 }

@@ -253,6 +264,10 @@ public class AnalyticsProcessManager
         process.writeRecord(headerRecord);
     }

+    private void indexDataCounts(DataCounts dataCounts) {
+        IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias());
+    }
+
     private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state,
                               AnalyticsProcess<AnalyticsResult> process) {
         if (config.getAnalysis().persistsState() == false) {

@@ -353,9 +368,11 @@ public class AnalyticsProcessManager
         private final SetOnce<DataFrameDataExtractor> dataExtractor = new SetOnce<>();
         private final SetOnce<AnalyticsResultProcessor> resultProcessor = new SetOnce<>();
         private final SetOnce<String> failureReason = new SetOnce<>();
+        private final StatsPersister statsPersister;

         ProcessContext(DataFrameAnalyticsConfig config) {
             this.config = Objects.requireNonNull(config);
+            this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor);
         }

         String getFailureReason() {

@@ -378,6 +395,7 @@ public class AnalyticsProcessManager
             if (resultProcessor.get() != null) {
                 resultProcessor.get().cancel();
             }
+            statsPersister.cancel();
             if (process.get() != null) {
                 try {
                     process.get().kill();

@@ -434,7 +452,7 @@ public class AnalyticsProcessManager
             DataFrameRowsJoiner dataFrameRowsJoiner =
                 new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
             return new AnalyticsResultProcessor(
-                config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, resultsPersisterService,
+                config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,
                 dataExtractor.get().getAllExtractedFields());
         }
     }
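For reference, a minimal sketch (not part of the commit) of the counting flow that the writeDataRows change above introduces: skipped rows go to the skipped-docs counter, while every other row is routed to the training or test counter through the CrossValidationSplitter callbacks. The class name, the null-row stand-in for row.shouldSkip(), and the inline lambda splitter are illustrative only.

import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;

import java.util.Arrays;
import java.util.List;

public class DataCountsFlowSketch {

    public static void main(String[] args) {
        DataCountsTracker tracker = new DataCountsTracker();
        // Same shape as the factory's non-classification default shown below:
        // every processed row is counted as a training doc.
        CrossValidationSplitter splitter =
            (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();

        List<String[]> rows = Arrays.asList(new String[] {"1", "2"}, null, new String[] {"3", "4"});
        for (String[] row : rows) {
            if (row == null) {                      // stand-in for row.shouldSkip()
                tracker.incrementSkippedDocsCount();
            } else {
                splitter.process(row, tracker::incrementTrainingDocsCount, tracker::incrementTestDocsCount);
            }
        }
        // tracker now holds 2 training docs, 0 test docs, 1 skipped doc
    }
}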
@@ -11,14 +11,10 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.LatchedActionListener;
-import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.xcontent.ToXContent;
-import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.json.JsonXContent;
 import org.elasticsearch.license.License;
-import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;

@@ -31,18 +27,16 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import org.elasticsearch.xpack.core.security.user.XPackUser;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
+import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.MultiField;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
-import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

-import java.io.IOException;
 import java.time.Instant;
 import java.util.Collections;
 import java.util.Iterator;

@@ -51,7 +45,6 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Function;
 import java.util.stream.Collectors;

 import static java.util.stream.Collectors.toList;

@@ -77,7 +70,7 @@ public class AnalyticsResultProcessor
     private final StatsHolder statsHolder;
     private final TrainedModelProvider trainedModelProvider;
     private final DataFrameAnalyticsAuditor auditor;
-    private final ResultsPersisterService resultsPersisterService;
+    private final StatsPersister statsPersister;
     private final List<ExtractedField> fieldNames;
     private final CountDownLatch completionLatch = new CountDownLatch(1);
     private volatile String failure;

@@ -85,14 +78,13 @@ public class AnalyticsResultProcessor

     public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
                                     StatsHolder statsHolder, TrainedModelProvider trainedModelProvider,
-                                    DataFrameAnalyticsAuditor auditor, ResultsPersisterService resultsPersisterService,
-                                    List<ExtractedField> fieldNames) {
+                                    DataFrameAnalyticsAuditor auditor, StatsPersister statsPersister, List<ExtractedField> fieldNames) {
         this.analytics = Objects.requireNonNull(analytics);
         this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
         this.statsHolder = Objects.requireNonNull(statsHolder);
         this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
         this.auditor = Objects.requireNonNull(auditor);
-        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
+        this.statsPersister = Objects.requireNonNull(statsPersister);
         this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames));
     }

@@ -112,6 +104,7 @@ public class AnalyticsResultProcessor

     public void cancel() {
         dataFrameRowsJoiner.cancel();
+        statsPersister.cancel();
         isCancelled = true;
     }

@@ -176,22 +169,22 @@ public class AnalyticsResultProcessor
         MemoryUsage memoryUsage = result.getMemoryUsage();
         if (memoryUsage != null) {
             statsHolder.setMemoryUsage(memoryUsage);
-            indexStatsResult(memoryUsage, memoryUsage::documentId);
+            statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId);
         }
         OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats();
         if (outlierDetectionStats != null) {
             statsHolder.setAnalysisStats(outlierDetectionStats);
-            indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId);
+            statsPersister.persistWithRetry(outlierDetectionStats, outlierDetectionStats::documentId);
         }
         ClassificationStats classificationStats = result.getClassificationStats();
         if (classificationStats != null) {
             statsHolder.setAnalysisStats(classificationStats);
-            indexStatsResult(classificationStats, classificationStats::documentId);
+            statsPersister.persistWithRetry(classificationStats, classificationStats::documentId);
         }
         RegressionStats regressionStats = result.getRegressionStats();
         if (regressionStats != null) {
             statsHolder.setAnalysisStats(regressionStats);
-            indexStatsResult(regressionStats, regressionStats::documentId);
+            statsPersister.persistWithRetry(regressionStats, regressionStats::documentId);
         }
     }

@@ -274,23 +267,4 @@ public class AnalyticsResultProcessor
             failure = "error processing results; " + e.getMessage();
             auditor.error(analytics.getId(), "Error processing results; " + e.getMessage());
         }
-
-    private void indexStatsResult(ToXContentObject result, Function<String, String> docIdSupplier) {
-        try {
-            resultsPersisterService.indexWithRetry(analytics.getId(),
-                MlStatsIndex.writeAlias(),
-                result,
-                new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
-                WriteRequest.RefreshPolicy.IMMEDIATE,
-                docIdSupplier.apply(analytics.getId()),
-                () -> isCancelled == false,
-                errorMsg -> auditor.error(analytics.getId(),
-                    "failed to persist result with id [" + docIdSupplier.apply(analytics.getId()) + "]; " + errorMsg)
-            );
-        } catch (IOException ioe) {
-            LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", analytics.getId()), ioe);
-        } catch (Exception e) {
-            LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", analytics.getId()), e);
-        }
-    }
 }
@@ -10,5 +10,5 @@ package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
  */
 public interface CrossValidationSplitter {

-    void process(String[] row);
+    void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs);
 }

@@ -31,6 +31,6 @@ public class CrossValidationSplitterFactory
             return new RandomCrossValidationSplitter(
                 fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
         }
-        return row -> {};
+        return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
     }
 }
@@ -40,22 +40,25 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter
     }

     @Override
-    public void process(String[] row) {
-        if (canBeUsedForTraining(row)) {
-            if (isFirstRow) {
-                // Let's make sure we have at least one training row
-                isFirstRow = false;
-            } else if (isRandomlyExcludedFromTraining()) {
-                row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
-            }
+    public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
+        if (canBeUsedForTraining(row) && isPickedForTraining()) {
+            incrementTrainingDocs.run();
+        } else {
+            row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
+            incrementTestDocs.run();
         }
     }

     private boolean canBeUsedForTraining(String[] row) {
-        return row[dependentVariableIndex].length() > 0;
+        return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
     }

-    private boolean isRandomlyExcludedFromTraining() {
-        return random.nextDouble() * 100 > trainingPercent;
+    private boolean isPickedForTraining() {
+        if (isFirstRow) {
+            // Let's make sure we have at least one training row
+            isFirstRow = false;
+            return true;
+        }
+        return random.nextDouble() * 100 <= trainingPercent;
     }
 }
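For reference, a rough sketch (not part of the commit) of how the reworked splitter behaves. It assumes the sketch lives in the same package as the splitter (the class is package-private) and uses the constructor arguments shown in the factory above; the field names, seed, and counters are made up for the example.

package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;   // assumed: the splitter is package-private

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;

public class RandomSplitterSketch {

    public static void main(String[] args) {
        RandomCrossValidationSplitter splitter =
            new RandomCrossValidationSplitter(Arrays.asList("feature", "target"), "target", 50.0, 42L);

        AtomicLong training = new AtomicLong();
        AtomicLong test = new AtomicLong();

        for (int i = 0; i < 1_000; i++) {
            String[] row = new String[] {"some-feature-value", "some-target-value"};
            // Rows not picked for training get their dependent variable replaced with
            // DataFrameDataExtractor.NULL_VALUE and are counted as test docs.
            splitter.process(row, training::incrementAndGet, test::incrementAndGet);
        }

        // With trainingPercent = 50.0 roughly half the rows should land in each bucket,
        // and the first usable row is always kept for training.
        System.out.println("training=" + training.get() + " test=" + test.get());
    }
}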
@@ -0,0 +1,37 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.stats;
+
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
+
+public class DataCountsTracker {
+
+    private volatile long trainingDocsCount;
+    private volatile long testDocsCount;
+    private volatile long skippedDocsCount;
+
+    public void incrementTrainingDocsCount() {
+        trainingDocsCount++;
+    }
+
+    public void incrementTestDocsCount() {
+        testDocsCount++;
+    }
+
+    public void incrementSkippedDocsCount() {
+        skippedDocsCount++;
+    }
+
+    public DataCounts report(String jobId) {
+        return new DataCounts(
+            jobId,
+            trainingDocsCount,
+            testDocsCount,
+            skippedDocsCount
+        );
+    }
+}
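For reference, a small sketch (not part of the commit) of turning the tracker above into the DataCounts object that the stats API and the ML stats index consume; the job id is a placeholder, and the use of DataCounts.documentId as a static id helper is inferred from the DataCounts::documentId reference in AnalyticsProcessManager above.

import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;

public class DataCountsReportSketch {

    public static void main(String[] args) {
        DataCountsTracker tracker = new DataCountsTracker();
        tracker.incrementTrainingDocsCount();
        tracker.incrementTrainingDocsCount();
        tracker.incrementTestDocsCount();
        tracker.incrementSkippedDocsCount();

        // Snapshot of the current counters: 2 training docs, 1 test doc, 1 skipped doc.
        DataCounts dataCounts = tracker.report("my-job");

        // Assumed static helper, matching the DataCounts::documentId method reference above:
        // it derives the id of the stats document from the job id.
        String statsDocId = DataCounts.documentId("my-job");
        System.out.println(statsDocId + " -> " + dataCounts);
    }
}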
@@ -19,11 +19,13 @@ public class StatsHolder
     private final ProgressTracker progressTracker;
     private final AtomicReference<MemoryUsage> memoryUsageHolder;
     private final AtomicReference<AnalysisStats> analysisStatsHolder;
+    private final DataCountsTracker dataCountsTracker;

     public StatsHolder() {
         progressTracker = new ProgressTracker();
         memoryUsageHolder = new AtomicReference<>();
         analysisStatsHolder = new AtomicReference<>();
+        dataCountsTracker = new DataCountsTracker();
     }

     public ProgressTracker getProgressTracker() {

@@ -45,4 +47,8 @@
     public AnalysisStats getAnalysisStats() {
         return analysisStatsHolder.get();
     }
+
+    public DataCountsTracker getDataCountsTracker() {
+        return dataCountsTracker;
+    }
 }
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.stats;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.xpack.core.ml.MlStatsIndex;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Objects;
+import java.util.function.Function;
+
+public class StatsPersister {
+
+    private static final Logger LOGGER = LogManager.getLogger(StatsPersister.class);
+
+    private final String jobId;
+    private final ResultsPersisterService resultsPersisterService;
+    private final DataFrameAnalyticsAuditor auditor;
+    private volatile boolean isCancelled;
+
+    public StatsPersister(String jobId, ResultsPersisterService resultsPersisterService, DataFrameAnalyticsAuditor auditor) {
+        this.jobId = Objects.requireNonNull(jobId);
+        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
+        this.auditor = Objects.requireNonNull(auditor);
+    }
+
+    public void persistWithRetry(ToXContentObject result, Function<String, String> docIdSupplier) {
+        if (isCancelled) {
+            return;
+        }
+
+        try {
+            resultsPersisterService.indexWithRetry(jobId,
+                MlStatsIndex.writeAlias(),
+                result,
+                new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
+                WriteRequest.RefreshPolicy.IMMEDIATE,
+                docIdSupplier.apply(jobId),
+                () -> isCancelled == false,
+                errorMsg -> auditor.error(jobId,
+                    "failed to persist result with id [" + docIdSupplier.apply(jobId) + "]; " + errorMsg)
+            );
+        } catch (IOException ioe) {
+            LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", jobId), ioe);
+        } catch (Exception e) {
+            LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", jobId), e);
+        }
+    }
+
+    public void cancel() {
+        isCancelled = true;
+    }
+}
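For reference, a test-style sketch (not part of the commit) of wiring the new StatsPersister, with Mockito mocks standing in for its collaborators in the same way the tests in this commit mock theirs; the job id and counts are placeholders, and with mocks this only exercises the call shape rather than real indexing.

import static org.mockito.Mockito.mock;

import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

public class StatsPersisterSketch {

    public static void main(String[] args) {
        ResultsPersisterService resultsPersisterService = mock(ResultsPersisterService.class);
        DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
        StatsPersister statsPersister = new StatsPersister("my-job", resultsPersisterService, auditor);

        DataCountsTracker tracker = new DataCountsTracker();
        tracker.incrementTrainingDocsCount();

        // Same call shape as in AnalyticsProcessManager above; with a real ResultsPersisterService
        // this would index the data counts document into the ML stats index with retries.
        statsPersister.persistWithRetry(tracker.report("my-job"), DataCounts::documentId);

        // After cancel() further persist calls are no-ops, mirroring how the process manager
        // and result processor shut the persister down when a job is stopped.
        statsPersister.cancel();
        statsPersister.persistWithRetry(tracker.report("my-job"), DataCounts::documentId);
    }
}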
@@ -324,41 +324,43 @@ public class DataFrameDataExtractorTests extends ESTestCase
         assertThat(searchRequest, containsString("\"_source\":{\"includes\":[\"field_2\"],\"excludes\":[]}"));
     }

-    public void testMissingValues_GivenShouldNotInclude() throws IOException {
+    public void testCollectDataSummary_GivenAnalysisSupportsMissingFields() {
+        TestExtractor dataExtractor = createExtractor(true, true);
+
+        // First and only batch
+        SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2));
+        dataExtractor.setNextResponse(response);
+
+        DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
+
+        assertThat(dataSummary.rows, equalTo(2L));
+        assertThat(dataSummary.cols, equalTo(2));
+
+        assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1));
+        String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", "");
+        assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}"));
+    }
+
+    public void testCollectDataSummary_GivenAnalysisDoesNotSupportMissingFields() {
         TestExtractor dataExtractor = createExtractor(true, false);

         // First and only batch
-        SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
-        dataExtractor.setNextResponse(response1);
-
-        // Empty
-        SearchResponse lastAndEmptyResponse = createEmptySearchResponse();
-        dataExtractor.setNextResponse(lastAndEmptyResponse);
-
-        assertThat(dataExtractor.hasNext(), is(true));
-
-        // First batch
-        Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
-        assertThat(rows.isPresent(), is(true));
-        assertThat(rows.get().size(), equalTo(3));
-
-        assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"}));
-        assertThat(rows.get().get(1).getValues(), is(nullValue()));
-        assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"}));
-
-        assertThat(rows.get().get(0).shouldSkip(), is(false));
-        assertThat(rows.get().get(1).shouldSkip(), is(true));
-        assertThat(rows.get().get(2).shouldSkip(), is(false));
-
-        assertThat(dataExtractor.hasNext(), is(true));
-
-        // Third batch should return empty
-        rows = dataExtractor.next();
-        assertThat(rows.isPresent(), is(false));
-        assertThat(dataExtractor.hasNext(), is(false));
+        SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2));
+        dataExtractor.setNextResponse(response);
+
+        DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
+
+        assertThat(dataSummary.rows, equalTo(2L));
+        assertThat(dataSummary.cols, equalTo(2));
+
+        assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1));
+        String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", "");
+        assertThat(searchRequest, containsString(
+            "\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"bool\":{\"filter\":" +
+                "[{\"exists\":{\"field\":\"field_1\",\"boost\":1.0}},{\"exists\":{\"field\":\"field_2\",\"boost\":1.0}}]"));
     }

-    public void testMissingValues_GivenShouldInclude() throws IOException {
+    public void testMissingValues_GivenSupported() throws IOException {
         TestExtractor dataExtractor = createExtractor(true, true);

         // First and only batch

@@ -393,6 +395,40 @@ public class DataFrameDataExtractorTests extends ESTestCase
         assertThat(dataExtractor.hasNext(), is(false));
     }

+    public void testMissingValues_GivenNotSupported() throws IOException {
+        TestExtractor dataExtractor = createExtractor(true, false);
+
+        // First and only batch
+        SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
+        dataExtractor.setNextResponse(response1);
+
+        // Empty
+        SearchResponse lastAndEmptyResponse = createEmptySearchResponse();
+        dataExtractor.setNextResponse(lastAndEmptyResponse);
+
+        assertThat(dataExtractor.hasNext(), is(true));
+
+        // First batch
+        Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
+        assertThat(rows.isPresent(), is(true));
+        assertThat(rows.get().size(), equalTo(3));
+
+        assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"}));
+        assertThat(rows.get().get(1).getValues(), is(nullValue()));
+        assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"}));
+
+        assertThat(rows.get().get(0).shouldSkip(), is(false));
+        assertThat(rows.get().get(1).shouldSkip(), is(true));
+        assertThat(rows.get().get(2).shouldSkip(), is(false));
+
+        assertThat(dataExtractor.hasNext(), is(true));
+
+        // Third batch should return empty
+        rows = dataExtractor.next();
+        assertThat(rows.isPresent(), is(false));
+        assertThat(dataExtractor.hasNext(), is(false));
+    }
+
     public void testGetCategoricalFields() {
         // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
         extractedFields = new ExtractedFields(Arrays.asList(

@@ -424,9 +460,9 @@
             containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
     }

-    private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) {
+    private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
         DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(
-            JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues);
+            JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, supportsRowsWithMissingValues);
         return new TestExtractor(client, context);
     }
@@ -24,13 +24,13 @@ import org.elasticsearch.xpack.core.security.user.XPackUser;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
+import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
 import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.extractor.MultiField;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
-import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 import org.junit.Before;
 import org.mockito.ArgumentCaptor;
 import org.mockito.InOrder;

@@ -66,7 +66,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase
     private StatsHolder statsHolder = new StatsHolder();
     private TrainedModelProvider trainedModelProvider;
     private DataFrameAnalyticsAuditor auditor;
-    private ResultsPersisterService resultsPersisterService;
+    private StatsPersister statsPersister;
     private DataFrameAnalyticsConfig analyticsConfig;

     @Before

@@ -76,7 +76,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase
         dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class);
         trainedModelProvider = mock(TrainedModelProvider.class);
         auditor = mock(DataFrameAnalyticsAuditor.class);
-        resultsPersisterService = mock(ResultsPersisterService.class);
+        statsPersister = mock(StatsPersister.class);
         analyticsConfig = new DataFrameAnalyticsConfig.Builder()
             .setId(JOB_ID)
             .setDescription(JOB_DESCRIPTION)

@@ -251,7 +251,7 @@
             statsHolder,
             trainedModelProvider,
             auditor,
-            resultsPersisterService,
+            statsPersister,
             fieldNames);
     }
 }
@@ -26,6 +26,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase
     private int dependentVariableIndex;
     private String dependentVariable;
     private long randomizeSeed;
+    private long trainingDocsCount;
+    private long testDocsCount;

     @Before
     public void setUpTests() {

@@ -40,47 +42,48 @@
     }

     public void testProcess_GivenRowsWithoutDependentVariableValue() {
-        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = createSplitter(50.0);

         for (int i = 0; i < 100; i++) {
             String[] row = new String[fields.size()];
             for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
-                String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10);
+                String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10);
                 row[fieldIndex] = value;
             }

             String[] processedRow = Arrays.copyOf(row, row.length);
-            crossValidationSplitter.process(processedRow);
+            crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);

             // As all these rows have no dependent variable value, they're not for training and should be unaffected
             assertThat(Arrays.equals(processedRow, row), is(true));
         }
+        assertThat(trainingDocsCount, equalTo(0L));
+        assertThat(testDocsCount, equalTo(100L));
     }

     public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
-        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
-            fields, dependentVariable, 100.0, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = createSplitter(100.0);

         for (int i = 0; i < 100; i++) {
             String[] row = new String[fields.size()];
             for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
-                String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10);
-                row[fieldIndex] = value;
+                row[fieldIndex] = randomAlphaOfLength(10);
             }

             String[] processedRow = Arrays.copyOf(row, row.length);
-            crossValidationSplitter.process(processedRow);
+            crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);

             // We should pick them all as training percent is 100
             assertThat(Arrays.equals(processedRow, row), is(true));
         }
+        assertThat(trainingDocsCount, equalTo(100L));
+        assertThat(testDocsCount, equalTo(0L));
     }

     public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
         double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
         double trainingFraction = trainingPercent / 100;
-        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
-            fields, dependentVariable, trainingPercent, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);

         int runCount = 20;
         int rowsCount = 1000;

@@ -94,7 +97,7 @@
             }

             String[] processedRow = Arrays.copyOf(row, row.length);
-            crossValidationSplitter.process(processedRow);
+            crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);

             for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
                 if (fieldIndex != dependentVariableIndex) {

@@ -126,8 +129,7 @@
     }

     public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
-        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
-            fields, dependentVariable, 1.0, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = createSplitter(1.0);

         // We have some non-training rows and then a training row to check
         // we maintain the first training row and not just the first row

@@ -135,16 +137,30 @@
             String[] row = new String[fields.size()];
             for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
                 if (i < 9 && fieldIndex == dependentVariableIndex) {
-                    row[fieldIndex] = "";
+                    row[fieldIndex] = DataFrameDataExtractor.NULL_VALUE;
                 } else {
                     row[fieldIndex] = randomAlphaOfLength(10);
                 }
             }

             String[] processedRow = Arrays.copyOf(row, row.length);
-            crossValidationSplitter.process(processedRow);
+            crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);

             assertThat(Arrays.equals(processedRow, row), is(true));
         }
+        assertThat(trainingDocsCount, equalTo(1L));
+        assertThat(testDocsCount, equalTo(9L));
+    }
+
+    private RandomCrossValidationSplitter createSplitter(double trainingPercent) {
+        return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed);
+    }
+
+    private void incrementTrainingDocsCount() {
+        trainingDocsCount++;
+    }
+
+    private void incrementTestDocsCount() {
+        testDocsCount++;
     }
 }