diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index 70169ce09b2..acdb9cccca1 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -21,6 +21,7 @@ package org.elasticsearch.client.ml.dataframe; import org.elasticsearch.client.ml.NodeAttributes; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -47,6 +48,7 @@ public class DataFrameAnalyticsStats { static final ParseField STATE = new ParseField("state"); static final ParseField FAILURE_REASON = new ParseField("failure_reason"); static final ParseField PROGRESS = new ParseField("progress"); + static final ParseField DATA_COUNTS = new ParseField("data_counts"); static final ParseField MEMORY_USAGE = new ParseField("memory_usage"); static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats"); static final ParseField NODE = new ParseField("node"); @@ -60,10 +62,11 @@ public class DataFrameAnalyticsStats { (DataFrameAnalyticsState) args[1], (String) args[2], (List) args[3], - (MemoryUsage) args[4], - (AnalysisStats) args[5], - (NodeAttributes) args[6], - (String) args[7])); + (DataCounts) args[4], + (MemoryUsage) args[5], + (AnalysisStats) args[6], + (NodeAttributes) args[7], + (String) args[8])); static { PARSER.declareString(constructorArg(), ID); @@ -75,6 +78,7 @@ public class DataFrameAnalyticsStats { }, STATE, ObjectParser.ValueType.STRING); PARSER.declareString(optionalConstructorArg(), FAILURE_REASON); PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS); + PARSER.declareObject(optionalConstructorArg(), DataCounts.PARSER, DATA_COUNTS); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE); PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS); PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE); @@ -93,19 +97,21 @@ public class DataFrameAnalyticsStats { private final DataFrameAnalyticsState state; private final String failureReason; private final List progress; + private final DataCounts dataCounts; private final MemoryUsage memoryUsage; private final AnalysisStats analysisStats; private final NodeAttributes node; private final String assignmentExplanation; public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, - @Nullable List progress, @Nullable MemoryUsage memoryUsage, - @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, + @Nullable List progress, @Nullable DataCounts dataCounts, + @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { this.id = id; this.state = state; this.failureReason = failureReason; this.progress = progress; + this.dataCounts = dataCounts; this.memoryUsage = memoryUsage; this.analysisStats = analysisStats; this.node = node; @@ -128,6 +134,11 @@ public class DataFrameAnalyticsStats { return progress; } + @Nullable + public DataCounts getDataCounts() { + return dataCounts; + } + @Nullable public MemoryUsage getMemoryUsage() { return memoryUsage; @@ -156,6 +167,7 @@ public class DataFrameAnalyticsStats { && Objects.equals(state, other.state) && Objects.equals(failureReason, other.failureReason) && Objects.equals(progress, other.progress) + && Objects.equals(dataCounts, other.dataCounts) && Objects.equals(memoryUsage, other.memoryUsage) && Objects.equals(analysisStats, other.analysisStats) && Objects.equals(node, other.node) @@ -164,7 +176,7 @@ public class DataFrameAnalyticsStats { @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, assignmentExplanation); } @Override @@ -174,6 +186,7 @@ public class DataFrameAnalyticsStats { .add("state", state) .add("failureReason", failureReason) .add("progress", progress) + .add("dataCounts", dataCounts) .add("memoryUsage", memoryUsage) .add("analysisStats", analysisStats) .add("node", node) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java new file mode 100644 index 00000000000..b7a90b1f0b5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java @@ -0,0 +1,119 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class DataCounts implements ToXContentObject { + + public static final String TYPE_VALUE = "analytics_data_counts"; + + public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count"); + public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count"); + public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE_VALUE, true, + a -> { + Long trainingDocsCount = (Long) a[0]; + Long testDocsCount = (Long) a[1]; + Long skippedDocsCount = (Long) a[2]; + return new DataCounts( + getOrDefault(trainingDocsCount, 0L), + getOrDefault(testDocsCount, 0L), + getOrDefault(skippedDocsCount, 0L) + ); + }); + + static { + PARSER.declareLong(optionalConstructorArg(), TRAINING_DOCS_COUNT); + PARSER.declareLong(optionalConstructorArg(), TEST_DOCS_COUNT); + PARSER.declareLong(optionalConstructorArg(), SKIPPED_DOCS_COUNT); + } + + private final long trainingDocsCount; + private final long testDocsCount; + private final long skippedDocsCount; + + public DataCounts(long trainingDocsCount, long testDocsCount, long skippedDocsCount) { + this.trainingDocsCount = trainingDocsCount; + this.testDocsCount = testDocsCount; + this.skippedDocsCount = skippedDocsCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount); + builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount); + builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataCounts that = (DataCounts) o; + return trainingDocsCount == that.trainingDocsCount + && testDocsCount == that.testDocsCount + && skippedDocsCount == that.skippedDocsCount; + } + + @Override + public int hashCode() { + return Objects.hash(trainingDocsCount, testDocsCount, skippedDocsCount); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount) + .add(TEST_DOCS_COUNT.getPreferredName(), testDocsCount) + .add(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount) + .toString(); + } + + public long getTrainingDocsCount() { + return trainingDocsCount; + } + + public long getTestDocsCount() { + return testDocsCount; + } + + public long getSkippedDocsCount() { + return skippedDocsCount; + } + + private static T getOrDefault(@Nullable T value, T defaultValue) { + return value != null ? value : defaultValue; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java index 25345181982..d251f568dfa 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.client.ml.NodeAttributesTests; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.client.ml.dataframe.stats.common.DataCountsTests; import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests; import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests; @@ -68,6 +69,7 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase { randomFrom(DataFrameAnalyticsState.values()), randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : createRandomProgress(), + randomBoolean() ? null : DataCountsTests.createRandom(), randomBoolean() ? null : MemoryUsageTests.createRandom(), analysisStats, randomBoolean() ? null : NodeAttributesTests.createRandom(), @@ -93,6 +95,9 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase { if (stats.getProgress() != null) { builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress()); } + if (stats.getDataCounts() != null) { + builder.field(DataFrameAnalyticsStats.DATA_COUNTS.getPreferredName(), stats.getDataCounts()); + } if (stats.getMemoryUsage() != null) { builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java new file mode 100644 index 00000000000..5e877e2d40f --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class DataCountsTests extends AbstractXContentTestCase { + + @Override + protected DataCounts createTestInstance() { + return createRandom(); + } + + public static DataCounts createRandom() { + return new DataCounts( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } + + @Override + protected DataCounts doParseInstance(XContentParser parser) throws IOException { + return DataCounts.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index 209058e0046..e37ccbbefc5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -165,6 +166,9 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType progress; + @Nullable + private final DataCounts dataCounts; + @Nullable private final MemoryUsage memoryUsage; @@ -177,12 +181,13 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType progress, - @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node, - @Nullable String assignmentExplanation) { + @Nullable DataCounts dataCounts, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, + @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); this.failureReason = failureReason; this.progress = Objects.requireNonNull(progress); + this.dataCounts = dataCounts; this.memoryUsage = memoryUsage; this.analysisStats = analysisStats; this.node = node; @@ -198,6 +203,11 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, + a -> new DataCounts((String) a[0], (long) a[1], (long) a[2], (long) a[3])); + + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); + parser.declareLong(ConstructingObjectParser.constructorArg(), TRAINING_DOCS_COUNT); + parser.declareLong(ConstructingObjectParser.constructorArg(), TEST_DOCS_COUNT); + parser.declareLong(ConstructingObjectParser.constructorArg(), SKIPPED_DOCS_COUNT); + return parser; + } + + private final String jobId; + private final long trainingDocsCount; + private final long testDocsCount; + private final long skippedDocsCount; + + public DataCounts(String jobId, long trainingDocsCount, long testDocsCount, long skippedDocsCount) { + this.jobId = Objects.requireNonNull(jobId); + this.trainingDocsCount = trainingDocsCount; + this.testDocsCount = testDocsCount; + this.skippedDocsCount = skippedDocsCount; + } + + public DataCounts(StreamInput in) throws IOException { + this.jobId = in.readString(); + this.trainingDocsCount = in.readVLong(); + this.testDocsCount = in.readVLong(); + this.skippedDocsCount = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(jobId); + out.writeVLong(trainingDocsCount); + out.writeVLong(testDocsCount); + out.writeVLong(skippedDocsCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); + } + builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount); + builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount); + builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataCounts that = (DataCounts) o; + return Objects.equals(jobId, that.jobId) + && trainingDocsCount == that.trainingDocsCount + && testDocsCount == that.testDocsCount + && skippedDocsCount == that.skippedDocsCount; + } + + @Override + public int hashCode() { + return Objects.hash(jobId, trainingDocsCount, testDocsCount, skippedDocsCount); + } + + public static String documentId(String jobId) { + return TYPE_VALUE + "_" + jobId; + } + + public String getJobId() { + return jobId; + } + + public long getTrainingDocsCount() { + return trainingDocsCount; + } + + public long getTestDocsCount() { + return testDocsCount; + } + + public long getSkippedDocsCount() { + return skippedDocsCount; + } +} diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json index f7f5e1e4d20..d30586bebb1 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json @@ -85,6 +85,9 @@ "peak_usage_bytes" : { "type" : "long" }, + "skipped_docs_count": { + "type": "long" + }, "timestamp" : { "type" : "date" }, @@ -98,6 +101,12 @@ } } }, + "test_docs_count": { + "type": "long" + }, + "training_docs_count": { + "type": "long" + }, "type" : { "type" : "keyword" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index 5cb2b3fef54..3f957f95a90 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCountsTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; @@ -42,6 +44,7 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS List progress = new ArrayList<>(progressSize); IntStream.of(progressSize).forEach(progressIndex -> progress.add( new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)))); + DataCounts dataCounts = randomBoolean() ? null : DataCountsTests.createRandom(); MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom(); AnalysisStats analysisStats = randomBoolean() ? null : randomFrom( @@ -50,7 +53,7 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS RegressionStatsTests.createRandom() ); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null, + randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, dataCounts, memoryUsage, analysisStats, null, randomAlphaOfLength(20)); analytics.add(stats); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java new file mode 100644 index 00000000000..84033d49de6 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; + +public class DataCountsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected DataCounts mutateInstanceForVersion(DataCounts instance, Version version) { + return instance; + } + + @Override + protected DataCounts doParseInstance(XContentParser parser) throws IOException { + return lenient ? DataCounts.LENIENT_PARSER.apply(parser, null) : DataCounts.STRICT_PARSER.apply(parser, null); + } + + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataCounts::new; + } + + @Override + protected DataCounts createTestInstance() { + return createRandom(); + } + + public static DataCounts createRandom() { + return new DataCounts( + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 51d33ea62c6..f0b39889f9d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -49,6 +50,7 @@ import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.startsWith; @@ -158,6 +160,12 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(300L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); @@ -225,6 +233,14 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(300L)); + assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(300L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java index 26ae36be99f..c30f1c1a983 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.junit.After; @@ -79,6 +80,12 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt startAnalytics(id); waitUntilAnalyticsIsStopped(id); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id); + assertThat(stats.getDataCounts().getJobId(), equalTo(id)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(2L)); + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); for (SearchHit hit : sourceData.getHits()) { GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 313cbac6de1..3536971bcf2 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -33,6 +34,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -143,6 +145,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(350L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); @@ -199,6 +208,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(350L)); + assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(350L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index f1b2c5eb2ee..70c5150eda4 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -85,6 +85,11 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest startAnalytics(id); waitUntilAnalyticsIsStopped(id); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id); + assertThat(stats.getDataCounts().getJobId(), equalTo(id)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); double scoreOfOutlier = 0.0; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 6efca760917..1d8dd1ebfaa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; @@ -108,6 +109,7 @@ public class TransportGetDataFrameAnalyticsStatsAction Stats stats = buildStats( task.getParams().getId(), statsHolder.getProgressTracker().report(), + statsHolder.getDataCountsTracker().report(task.getParams().getId()), statsHolder.getMemoryUsage(), statsHolder.getAnalysisStats() ); @@ -198,6 +200,7 @@ public class TransportGetDataFrameAnalyticsStatsAction MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); multiSearchRequest.add(buildStoredProgressSearch(configId)); + multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE)); @@ -222,6 +225,7 @@ public class TransportGetDataFrameAnalyticsStatsAction } listener.onResponse(buildStats(configId, retrievedStatsHolder.progress.get(), + retrievedStatsHolder.dataCounts, retrievedStatsHolder.memoryUsage, retrievedStatsHolder.analysisStats )); @@ -256,6 +260,8 @@ public class TransportGetDataFrameAnalyticsStatsAction String hitId = hit.getId(); if (StoredProgress.documentId(configId).equals(hitId)) { retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER); + } else if (DataCounts.documentId(configId).equals(hitId)) { + retrievedStatsHolder.dataCounts = MlParserUtils.parse(hit, DataCounts.LENIENT_PARSER); } else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) { retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER); } else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) { @@ -271,6 +277,7 @@ public class TransportGetDataFrameAnalyticsStatsAction private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List progress, + DataCounts dataCounts, MemoryUsage memoryUsage, AnalysisStats analysisStats) { ClusterState clusterState = clusterService.state(); @@ -293,6 +300,7 @@ public class TransportGetDataFrameAnalyticsStatsAction analyticsState, failureReason, progress, + dataCounts, memoryUsage, analysisStats, node, @@ -303,6 +311,7 @@ public class TransportGetDataFrameAnalyticsStatsAction private static class RetrievedStatsHolder { private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report()); + private volatile DataCounts dataCounts; private volatile MemoryUsage memoryUsage; private volatile AnalysisStats analysisStats; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index ba1bdf9a6a5..aad06b71c07 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -19,6 +19,9 @@ import org.elasticsearch.action.search.SearchScrollRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; @@ -187,7 +190,7 @@ public class DataFrameDataExtractor { if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { extractedValues[i] = Objects.toString(values[0]); } else { - if (values.length == 0 && context.includeRowsWithMissingValues) { + if (values.length == 0 && context.supportsRowsWithMissingValues) { // if values is empty then it means it's a missing value extractedValues[i] = NULL_VALUE; } else { @@ -263,13 +266,29 @@ public class DataFrameDataExtractor { } private SearchRequestBuilder buildDataSummarySearchRequestBuilder() { + + QueryBuilder summaryQuery = context.query; + if (context.supportsRowsWithMissingValues == false) { + summaryQuery = QueryBuilders.boolQuery() + .filter(summaryQuery) + .filter(allExtractedFieldsExistQuery()); + } + return new SearchRequestBuilder(client, SearchAction.INSTANCE) .setIndices(context.indices) .setSize(0) - .setQuery(context.query) + .setQuery(summaryQuery) .setTrackTotalHits(true); } + private QueryBuilder allExtractedFieldsExistQuery() { + BoolQueryBuilder query = QueryBuilders.boolQuery(); + for (ExtractedField field : context.extractedFields.getAllFields()) { + query.filter(QueryBuilders.existsQuery(field.getName())); + } + return query; + } + public Set getCategoricalFields(DataFrameAnalysis analysis) { return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java index 0cf391bc33b..64ad4bed452 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java @@ -21,10 +21,10 @@ public class DataFrameDataExtractorContext { final int scrollSize; final Map headers; final boolean includeSource; - final boolean includeRowsWithMissingValues; + final boolean supportsRowsWithMissingValues; DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List indices, QueryBuilder query, int scrollSize, - Map headers, boolean includeSource, boolean includeRowsWithMissingValues) { + Map headers, boolean includeSource, boolean supportsRowsWithMissingValues) { this.jobId = Objects.requireNonNull(jobId); this.extractedFields = Objects.requireNonNull(extractedFields); this.indices = indices.toArray(new String[indices.size()]); @@ -32,6 +32,6 @@ public class DataFrameDataExtractorContext { this.scrollSize = scrollSize; this.headers = headers; this.includeSource = includeSource; - this.includeRowsWithMissingValues = includeRowsWithMissingValues; + this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 3243d92bf77..a699e16a7d6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -7,11 +7,9 @@ package org.elasticsearch.xpack.ml.dataframe.extractor; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.util.Arrays; @@ -28,18 +26,18 @@ public class DataFrameDataExtractorFactory { private final QueryBuilder sourceQuery; private final ExtractedFields extractedFields; private final Map headers; - private final boolean includeRowsWithMissingValues; + private final boolean supportsRowsWithMissingValues; private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, QueryBuilder sourceQuery, ExtractedFields extractedFields, Map headers, - boolean includeRowsWithMissingValues) { + boolean supportsRowsWithMissingValues) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.indices = Objects.requireNonNull(indices); this.sourceQuery = Objects.requireNonNull(sourceQuery); this.extractedFields = Objects.requireNonNull(extractedFields); this.headers = headers; - this.includeRowsWithMissingValues = includeRowsWithMissingValues; + this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; } public DataFrameDataExtractor newExtractor(boolean includeSource) { @@ -47,11 +45,11 @@ public class DataFrameDataExtractorFactory { analyticsId, extractedFields, indices, - createQuery(), + QueryBuilders.boolQuery().filter(sourceQuery), 1000, headers, includeSource, - includeRowsWithMissingValues + supportsRowsWithMissingValues ); return new DataFrameDataExtractor(client, context); } @@ -60,23 +58,6 @@ public class DataFrameDataExtractorFactory { return extractedFields; } - private QueryBuilder createQuery() { - BoolQueryBuilder query = QueryBuilders.boolQuery(); - query.filter(sourceQuery); - if (includeRowsWithMissingValues == false) { - query.filter(allExtractedFieldsExistQuery()); - } - return query; - } - - private QueryBuilder allExtractedFieldsExistQuery() { - BoolQueryBuilder query = QueryBuilders.boolQuery(); - for (ExtractedField field : extractedFields.getAllFields()) { - query.filter(QueryBuilders.existsQuery(field.getName())); - } - return query; - } - /** * Create a new extractor factory * @@ -109,6 +90,7 @@ public class DataFrameDataExtractorFactory { extractedFieldsDetectorFactory.createFromDest(config, ActionListener.wrap( extractedFieldsDetector -> { ExtractedFields extractedFields = extractedFieldsDetector.detect().v1(); + DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(), Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields, config.getHeaders(), config.getAnalysis().supportsMissingValues()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index c7baad202d2..c1f09ff3fb3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; @@ -22,8 +23,10 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,7 +37,9 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter; import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -156,7 +161,10 @@ public class AnalyticsProcessManager { AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get(); try { writeHeaderRecord(dataExtractor, process); - writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker()); + writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(), + task.getStatsHolder().getDataCountsTracker()); + processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()), + DataCounts::documentId); process.writeEndOfDataMessage(); process.flushStream(); @@ -205,8 +213,8 @@ public class AnalyticsProcessManager { } } - private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, - DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException { + private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, DataFrameAnalysis analysis, + ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException { CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames()) .create(analysis); @@ -223,11 +231,14 @@ public class AnalyticsProcessManager { Optional> rows = dataExtractor.next(); if (rows.isPresent()) { for (DataFrameDataExtractor.Row row : rows.get()) { - if (row.shouldSkip() == false) { + if (row.shouldSkip()) { + dataCountsTracker.incrementSkippedDocsCount(); + } else { String[] rowValues = row.getValues(); System.arraycopy(rowValues, 0, record, 0, rowValues.length); record[record.length - 2] = String.valueOf(row.getChecksum()); - crossValidationSplitter.process(record); + crossValidationSplitter.process(record, dataCountsTracker::incrementTrainingDocsCount, + dataCountsTracker::incrementTestDocsCount); process.writeRecord(record); } } @@ -253,6 +264,10 @@ public class AnalyticsProcessManager { process.writeRecord(headerRecord); } + private void indexDataCounts(DataCounts dataCounts) { + IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias()); + } + private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state, AnalyticsProcess process) { if (config.getAnalysis().persistsState() == false) { @@ -353,9 +368,11 @@ public class AnalyticsProcessManager { private final SetOnce dataExtractor = new SetOnce<>(); private final SetOnce resultProcessor = new SetOnce<>(); private final SetOnce failureReason = new SetOnce<>(); + private final StatsPersister statsPersister; ProcessContext(DataFrameAnalyticsConfig config) { this.config = Objects.requireNonNull(config); + this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor); } String getFailureReason() { @@ -378,6 +395,7 @@ public class AnalyticsProcessManager { if (resultProcessor.get() != null) { resultProcessor.get().cancel(); } + statsPersister.cancel(); if (process.get() != null) { try { process.get().kill(); @@ -434,7 +452,7 @@ public class AnalyticsProcessManager { DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService); return new AnalyticsResultProcessor( - config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, resultsPersisterService, + config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister, dataExtractor.get().getAllExtractedFields()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 2afb2b65d5b..5a5e28683d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -11,14 +11,10 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.License; -import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; @@ -31,18 +27,16 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; -import java.io.IOException; import java.time.Instant; import java.util.Collections; import java.util.Iterator; @@ -51,7 +45,6 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.function.Function; import java.util.stream.Collectors; import static java.util.stream.Collectors.toList; @@ -77,7 +70,7 @@ public class AnalyticsResultProcessor { private final StatsHolder statsHolder; private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; - private final ResultsPersisterService resultsPersisterService; + private final StatsPersister statsPersister; private final List fieldNames; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; @@ -85,14 +78,13 @@ public class AnalyticsResultProcessor { public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, StatsHolder statsHolder, TrainedModelProvider trainedModelProvider, - DataFrameAnalyticsAuditor auditor, ResultsPersisterService resultsPersisterService, - List fieldNames) { + DataFrameAnalyticsAuditor auditor, StatsPersister statsPersister, List fieldNames) { this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.statsHolder = Objects.requireNonNull(statsHolder); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); - this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.statsPersister = Objects.requireNonNull(statsPersister); this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames)); } @@ -112,6 +104,7 @@ public class AnalyticsResultProcessor { public void cancel() { dataFrameRowsJoiner.cancel(); + statsPersister.cancel(); isCancelled = true; } @@ -176,22 +169,22 @@ public class AnalyticsResultProcessor { MemoryUsage memoryUsage = result.getMemoryUsage(); if (memoryUsage != null) { statsHolder.setMemoryUsage(memoryUsage); - indexStatsResult(memoryUsage, memoryUsage::documentId); + statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId); } OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats(); if (outlierDetectionStats != null) { statsHolder.setAnalysisStats(outlierDetectionStats); - indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId); + statsPersister.persistWithRetry(outlierDetectionStats, outlierDetectionStats::documentId); } ClassificationStats classificationStats = result.getClassificationStats(); if (classificationStats != null) { statsHolder.setAnalysisStats(classificationStats); - indexStatsResult(classificationStats, classificationStats::documentId); + statsPersister.persistWithRetry(classificationStats, classificationStats::documentId); } RegressionStats regressionStats = result.getRegressionStats(); if (regressionStats != null) { statsHolder.setAnalysisStats(regressionStats); - indexStatsResult(regressionStats, regressionStats::documentId); + statsPersister.persistWithRetry(regressionStats, regressionStats::documentId); } } @@ -274,23 +267,4 @@ public class AnalyticsResultProcessor { failure = "error processing results; " + e.getMessage(); auditor.error(analytics.getId(), "Error processing results; " + e.getMessage()); } - - private void indexStatsResult(ToXContentObject result, Function docIdSupplier) { - try { - resultsPersisterService.indexWithRetry(analytics.getId(), - MlStatsIndex.writeAlias(), - result, - new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), - WriteRequest.RefreshPolicy.IMMEDIATE, - docIdSupplier.apply(analytics.getId()), - () -> isCancelled == false, - errorMsg -> auditor.error(analytics.getId(), - "failed to persist result with id [" + docIdSupplier.apply(analytics.getId()) + "]; " + errorMsg) - ); - } catch (IOException ioe) { - LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", analytics.getId()), ioe); - } catch (Exception e) { - LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", analytics.getId()), e); - } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java index 5d12a2a81a6..fce602b28e2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java @@ -10,5 +10,5 @@ package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; */ public interface CrossValidationSplitter { - void process(String[] row); + void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java index 47c052dd0bf..986633aaa37 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java @@ -31,6 +31,6 @@ public class CrossValidationSplitterFactory { return new RandomCrossValidationSplitter( fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed()); } - return row -> {}; + return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run(); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java index 0afc59628e7..e4e343083ee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java @@ -40,22 +40,25 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter { } @Override - public void process(String[] row) { - if (canBeUsedForTraining(row)) { - if (isFirstRow) { - // Let's make sure we have at least one training row - isFirstRow = false; - } else if (isRandomlyExcludedFromTraining()) { - row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; - } + public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { + if (canBeUsedForTraining(row) && isPickedForTraining()) { + incrementTrainingDocs.run(); + } else { + row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; + incrementTestDocs.run(); } } private boolean canBeUsedForTraining(String[] row) { - return row[dependentVariableIndex].length() > 0; + return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; } - private boolean isRandomlyExcludedFromTraining() { - return random.nextDouble() * 100 > trainingPercent; + private boolean isPickedForTraining() { + if (isFirstRow) { + // Let's make sure we have at least one training row + isFirstRow = false; + return true; + } + return random.nextDouble() * 100 <= trainingPercent; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java new file mode 100644 index 00000000000..bed9f52b448 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; + +public class DataCountsTracker { + + private volatile long trainingDocsCount; + private volatile long testDocsCount; + private volatile long skippedDocsCount; + + public void incrementTrainingDocsCount() { + trainingDocsCount++; + } + + public void incrementTestDocsCount() { + testDocsCount++; + } + + public void incrementSkippedDocsCount() { + skippedDocsCount++; + } + + public DataCounts report(String jobId) { + return new DataCounts( + jobId, + trainingDocsCount, + testDocsCount, + skippedDocsCount + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index ff6b9ec7bcf..d01eee6a3a3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -19,11 +19,13 @@ public class StatsHolder { private final ProgressTracker progressTracker; private final AtomicReference memoryUsageHolder; private final AtomicReference analysisStatsHolder; + private final DataCountsTracker dataCountsTracker; public StatsHolder() { progressTracker = new ProgressTracker(); memoryUsageHolder = new AtomicReference<>(); analysisStatsHolder = new AtomicReference<>(); + dataCountsTracker = new DataCountsTracker(); } public ProgressTracker getProgressTracker() { @@ -45,4 +47,8 @@ public class StatsHolder { public AnalysisStats getAnalysisStats() { return analysisStatsHolder.get(); } + + public DataCountsTracker getDataCountsTracker() { + return dataCountsTracker; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java new file mode 100644 index 00000000000..eeb8924928c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; + +import java.io.IOException; +import java.util.Collections; +import java.util.Objects; +import java.util.function.Function; + +public class StatsPersister { + + private static final Logger LOGGER = LogManager.getLogger(StatsPersister.class); + + private final String jobId; + private final ResultsPersisterService resultsPersisterService; + private final DataFrameAnalyticsAuditor auditor; + private volatile boolean isCancelled; + + public StatsPersister(String jobId, ResultsPersisterService resultsPersisterService, DataFrameAnalyticsAuditor auditor) { + this.jobId = Objects.requireNonNull(jobId); + this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.auditor = Objects.requireNonNull(auditor); + } + + public void persistWithRetry(ToXContentObject result, Function docIdSupplier) { + if (isCancelled) { + return; + } + + try { + resultsPersisterService.indexWithRetry(jobId, + MlStatsIndex.writeAlias(), + result, + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), + WriteRequest.RefreshPolicy.IMMEDIATE, + docIdSupplier.apply(jobId), + () -> isCancelled == false, + errorMsg -> auditor.error(jobId, + "failed to persist result with id [" + docIdSupplier.apply(jobId) + "]; " + errorMsg) + ); + } catch (IOException ioe) { + LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", jobId), ioe); + } catch (Exception e) { + LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", jobId), e); + } + } + + public void cancel() { + isCancelled = true; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index c482105de89..3ea8277f8ee 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -324,41 +324,43 @@ public class DataFrameDataExtractorTests extends ESTestCase { assertThat(searchRequest, containsString("\"_source\":{\"includes\":[\"field_2\"],\"excludes\":[]}")); } - public void testMissingValues_GivenShouldNotInclude() throws IOException { + public void testCollectDataSummary_GivenAnalysisSupportsMissingFields() { + TestExtractor dataExtractor = createExtractor(true, true); + + // First and only batch + SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); + dataExtractor.setNextResponse(response); + + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + + assertThat(dataSummary.rows, equalTo(2L)); + assertThat(dataSummary.cols, equalTo(2)); + + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); + } + + public void testCollectDataSummary_GivenAnalysisDoesNotSupportMissingFields() { TestExtractor dataExtractor = createExtractor(true, false); // First and only batch - SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); - dataExtractor.setNextResponse(response1); + SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); + dataExtractor.setNextResponse(response); - // Empty - SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); - dataExtractor.setNextResponse(lastAndEmptyResponse); + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); - assertThat(dataExtractor.hasNext(), is(true)); + assertThat(dataSummary.rows, equalTo(2L)); + assertThat(dataSummary.cols, equalTo(2)); - // First batch - Optional> rows = dataExtractor.next(); - assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(3)); - - assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(rows.get().get(1).getValues(), is(nullValue())); - assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); - - assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(true)); - assertThat(rows.get().get(2).shouldSkip(), is(false)); - - assertThat(dataExtractor.hasNext(), is(true)); - - // Third batch should return empty - rows = dataExtractor.next(); - assertThat(rows.isPresent(), is(false)); - assertThat(dataExtractor.hasNext(), is(false)); + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString( + "\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"bool\":{\"filter\":" + + "[{\"exists\":{\"field\":\"field_1\",\"boost\":1.0}},{\"exists\":{\"field\":\"field_2\",\"boost\":1.0}}]")); } - public void testMissingValues_GivenShouldInclude() throws IOException { + public void testMissingValues_GivenSupported() throws IOException { TestExtractor dataExtractor = createExtractor(true, true); // First and only batch @@ -393,6 +395,40 @@ public class DataFrameDataExtractorTests extends ESTestCase { assertThat(dataExtractor.hasNext(), is(false)); } + public void testMissingValues_GivenNotSupported() throws IOException { + TestExtractor dataExtractor = createExtractor(true, false); + + // First and only batch + SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); + dataExtractor.setNextResponse(response1); + + // Empty + SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); + dataExtractor.setNextResponse(lastAndEmptyResponse); + + assertThat(dataExtractor.hasNext(), is(true)); + + // First batch + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(3)); + + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(rows.get().get(1).getValues(), is(nullValue())); + assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); + + assertThat(rows.get().get(0).shouldSkip(), is(false)); + assertThat(rows.get().get(1).shouldSkip(), is(true)); + assertThat(rows.get().get(2).shouldSkip(), is(false)); + + assertThat(dataExtractor.hasNext(), is(true)); + + // Third batch should return empty + rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(false)); + assertThat(dataExtractor.hasNext(), is(false)); + } + public void testGetCategoricalFields() { // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 extractedFields = new ExtractedFields(Arrays.asList( @@ -424,9 +460,9 @@ public class DataFrameDataExtractorTests extends ESTestCase { containsInAnyOrder("field_keyword", "field_text", "field_boolean")); } - private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) { + private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( - JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues); + JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, supportsRowsWithMissingValues); return new TestExtractor(client, context); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index da7bbff71d5..ee2b399ef2d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -24,13 +24,13 @@ import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; @@ -66,7 +66,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private StatsHolder statsHolder = new StatsHolder(); private TrainedModelProvider trainedModelProvider; private DataFrameAnalyticsAuditor auditor; - private ResultsPersisterService resultsPersisterService; + private StatsPersister statsPersister; private DataFrameAnalyticsConfig analyticsConfig; @Before @@ -76,7 +76,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class); trainedModelProvider = mock(TrainedModelProvider.class); auditor = mock(DataFrameAnalyticsAuditor.class); - resultsPersisterService = mock(ResultsPersisterService.class); + statsPersister = mock(StatsPersister.class); analyticsConfig = new DataFrameAnalyticsConfig.Builder() .setId(JOB_ID) .setDescription(JOB_DESCRIPTION) @@ -251,7 +251,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { statsHolder, trainedModelProvider, auditor, - resultsPersisterService, + statsPersister, fieldNames); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java index eea102e6738..0bbc9d75d8b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java @@ -26,6 +26,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { private int dependentVariableIndex; private String dependentVariable; private long randomizeSeed; + private long trainingDocsCount; + private long testDocsCount; @Before public void setUpTests() { @@ -40,47 +42,48 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(50.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { - String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10); + String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10); row[fieldIndex] = value; } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); // As all these rows have no dependent variable value, they're not for training and should be unaffected assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(0L)); + assertThat(testDocsCount, equalTo(100L)); } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, 100.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(100.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { - String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10); - row[fieldIndex] = value; + row[fieldIndex] = randomAlphaOfLength(10); } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); // We should pick them all as training percent is 100 assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(100L)); + assertThat(testDocsCount, equalTo(0L)); } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, trainingPercent, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent); int runCount = 20; int rowsCount = 1000; @@ -94,7 +97,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (fieldIndex != dependentVariableIndex) { @@ -126,8 +129,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, 1.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(1.0); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row @@ -135,16 +137,30 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (i < 9 && fieldIndex == dependentVariableIndex) { - row[fieldIndex] = ""; + row[fieldIndex] = DataFrameDataExtractor.NULL_VALUE; } else { row[fieldIndex] = randomAlphaOfLength(10); } } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(1L)); + assertThat(testDocsCount, equalTo(9L)); + } + + private RandomCrossValidationSplitter createSplitter(double trainingPercent) { + return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed); + } + + private void incrementTrainingDocsCount() { + trainingDocsCount++; + } + + private void incrementTestDocsCount() { + testDocsCount++; } }