[ML-DataFrame] refactor pivot to only take the pivot config (#41763)

refactor pivot class to only take the config at construction, other parameters are passed in as part of
method that require them
This commit is contained in:
Hendrik Muhs 2019-05-03 13:36:59 +02:00
parent c942277822
commit befe2a45b9
7 changed files with 63 additions and 67 deletions

View File

@ -39,7 +39,7 @@ public class QueryConfig extends AbstractDiffable<QueryConfig> implements Writea
private final Map<String, Object> source; private final Map<String, Object> source;
private final QueryBuilder query; private final QueryBuilder query;
static QueryConfig matchAll() { public static QueryConfig matchAll() {
return new QueryConfig(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()), return new QueryConfig(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()),
new MatchAllQueryBuilder()); new MatchAllQueryBuilder());
} }

View File

@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.dataframe.action.PreviewDataFrameTransformAction; import org.elasticsearch.xpack.core.dataframe.action.PreviewDataFrameTransformAction;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
import org.elasticsearch.xpack.core.dataframe.transforms.SourceConfig;
import org.elasticsearch.xpack.dataframe.transforms.pivot.Pivot; import org.elasticsearch.xpack.dataframe.transforms.pivot.Pivot;
import java.util.List; import java.util.List;
@ -61,24 +62,22 @@ public class TransportPreviewDataFrameTransformAction extends
final DataFrameTransformConfig config = request.getConfig(); final DataFrameTransformConfig config = request.getConfig();
Pivot pivot = new Pivot(config.getSource().getIndex(), Pivot pivot = new Pivot(config.getPivotConfig());
config.getSource().getQueryConfig().getQuery(),
config.getPivotConfig());
getPreview(pivot, ActionListener.wrap( getPreview(pivot, config.getSource(), ActionListener.wrap(
previewResponse -> listener.onResponse(new PreviewDataFrameTransformAction.Response(previewResponse)), previewResponse -> listener.onResponse(new PreviewDataFrameTransformAction.Response(previewResponse)),
listener::onFailure listener::onFailure
)); ));
} }
private void getPreview(Pivot pivot, ActionListener<List<Map<String, Object>>> listener) { private void getPreview(Pivot pivot, SourceConfig source, ActionListener<List<Map<String, Object>>> listener) {
pivot.deduceMappings(client, ActionListener.wrap( pivot.deduceMappings(client, source, ActionListener.wrap(
deducedMappings -> { deducedMappings -> {
ClientHelper.executeWithHeadersAsync(threadPool.getThreadContext().getHeaders(), ClientHelper.executeWithHeadersAsync(threadPool.getThreadContext().getHeaders(),
ClientHelper.DATA_FRAME_ORIGIN, ClientHelper.DATA_FRAME_ORIGIN,
client, client,
SearchAction.INSTANCE, SearchAction.INSTANCE,
pivot.buildSearchRequest(null, NUMBER_OF_PREVIEW_BUCKETS), pivot.buildSearchRequest(source, null, NUMBER_OF_PREVIEW_BUCKETS),
ActionListener.wrap( ActionListener.wrap(
r -> { r -> {
final CompositeAggregation agg = r.getAggregations().get(COMPOSITE_AGGREGATION_NAME); final CompositeAggregation agg = r.getAggregations().get(COMPOSITE_AGGREGATION_NAME);

View File

@ -190,9 +190,7 @@ public class TransportPutDataFrameTransformAction
private void putDataFrame(DataFrameTransformConfig config, ActionListener<AcknowledgedResponse> listener) { private void putDataFrame(DataFrameTransformConfig config, ActionListener<AcknowledgedResponse> listener) {
final Pivot pivot = new Pivot(config.getSource().getIndex(), final Pivot pivot = new Pivot(config.getPivotConfig());
config.getSource().getQueryConfig().getQuery(),
config.getPivotConfig());
// <5> Return the listener, or clean up destination index on failure. // <5> Return the listener, or clean up destination index on failure.
@ -210,6 +208,6 @@ public class TransportPutDataFrameTransformAction
); );
// <1> Validate our pivot // <1> Validate our pivot
pivot.validate(client, pivotValidationListener); pivot.validate(client, config.getSource(), pivotValidationListener);
} }
} }

View File

@ -224,9 +224,7 @@ public class TransportStartDataFrameTransformAction extends
private void createDestinationIndex(final DataFrameTransformConfig config, final ActionListener<Void> listener) { private void createDestinationIndex(final DataFrameTransformConfig config, final ActionListener<Void> listener) {
final Pivot pivot = new Pivot(config.getSource().getIndex(), final Pivot pivot = new Pivot(config.getPivotConfig());
config.getSource().getQueryConfig().getQuery(),
config.getPivotConfig());
ActionListener<Map<String, String>> deduceMappingsListener = ActionListener.wrap( ActionListener<Map<String, String>> deduceMappingsListener = ActionListener.wrap(
mappings -> DataframeIndex.createDestinationIndex(client, mappings -> DataframeIndex.createDestinationIndex(client,
@ -238,7 +236,7 @@ public class TransportStartDataFrameTransformAction extends
deduceTargetMappingsException)) deduceTargetMappingsException))
); );
pivot.deduceMappings(client, deduceMappingsListener); pivot.deduceMappings(client, config.getSource(), deduceMappingsListener);
} }
@Override @Override

View File

@ -16,7 +16,6 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation; import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.xpack.core.dataframe.DataFrameField; import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.DataFrameMessages; import org.elasticsearch.xpack.core.dataframe.DataFrameMessages;
@ -97,8 +96,7 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
@Override @Override
protected void onStart(long now, ActionListener<Void> listener) { protected void onStart(long now, ActionListener<Void> listener) {
try { try {
QueryBuilder queryBuilder = getConfig().getSource().getQueryConfig().getQuery(); pivot = new Pivot(getConfig().getPivotConfig());
pivot = new Pivot(getConfig().getSource().getIndex(), queryBuilder, getConfig().getPivotConfig());
// if we haven't set the page size yet, if it is set we might have reduced it after running into an out of memory // if we haven't set the page size yet, if it is set we might have reduced it after running into an out of memory
if (pageSize == 0) { if (pageSize == 0) {
@ -180,7 +178,7 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
@Override @Override
protected SearchRequest buildSearchRequest() { protected SearchRequest buildSearchRequest() {
return pivot.buildSearchRequest(getPosition(), pageSize); return pivot.buildSearchRequest(getConfig().getSource(), getPosition(), pageSize);
} }
/** /**

View File

@ -25,6 +25,7 @@ import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregati
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.dataframe.DataFrameMessages; import org.elasticsearch.xpack.core.dataframe.DataFrameMessages;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.SourceConfig;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@ -37,24 +38,21 @@ import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
public class Pivot { public class Pivot {
public static final int DEFAULT_INITIAL_PAGE_SIZE = 500; public static final int DEFAULT_INITIAL_PAGE_SIZE = 500;
public static final int TEST_QUERY_PAGE_SIZE = 50;
private static final String COMPOSITE_AGGREGATION_NAME = "_data_frame"; private static final String COMPOSITE_AGGREGATION_NAME = "_data_frame";
private final PivotConfig config; private final PivotConfig config;
private final String[] source;
// objects for re-using // objects for re-using
private final CompositeAggregationBuilder cachedCompositeAggregation; private final CompositeAggregationBuilder cachedCompositeAggregation;
private final SearchRequest cachedSearchRequest;
public Pivot(String[] source, QueryBuilder query, PivotConfig config) { public Pivot(PivotConfig config) {
this.source = source;
this.config = config; this.config = config;
this.cachedCompositeAggregation = createCompositeAggregation(config); this.cachedCompositeAggregation = createCompositeAggregation(config);
this.cachedSearchRequest = createSearchRequest(source, query, cachedCompositeAggregation);
} }
public void validate(Client client, final ActionListener<Boolean> listener) { public void validate(Client client, SourceConfig sourceConfig, final ActionListener<Boolean> listener) {
// step 1: check if used aggregations are supported // step 1: check if used aggregations are supported
for (AggregationBuilder agg : config.getAggregationConfig().getAggregatorFactories()) { for (AggregationBuilder agg : config.getAggregationConfig().getAggregatorFactories()) {
if (Aggregations.isSupportedByDataframe(agg.getType()) == false) { if (Aggregations.isSupportedByDataframe(agg.getType()) == false) {
@ -64,11 +62,11 @@ public class Pivot {
} }
// step 2: run a query to validate that config is valid // step 2: run a query to validate that config is valid
runTestQuery(client, listener); runTestQuery(client, sourceConfig, listener);
} }
public void deduceMappings(Client client, final ActionListener<Map<String, String>> listener) { public void deduceMappings(Client client, SourceConfig sourceConfig, final ActionListener<Map<String, String>> listener) {
SchemaUtil.deduceMappings(client, config, source, listener); SchemaUtil.deduceMappings(client, config, sourceConfig.getIndex(), listener);
} }
/** /**
@ -87,14 +85,24 @@ public class Pivot {
return DEFAULT_INITIAL_PAGE_SIZE; return DEFAULT_INITIAL_PAGE_SIZE;
} }
public SearchRequest buildSearchRequest(Map<String, Object> position, int pageSize) { public SearchRequest buildSearchRequest(SourceConfig sourceConfig, Map<String, Object> position, int pageSize) {
if (position != null) { QueryBuilder queryBuilder = sourceConfig.getQueryConfig().getQuery();
cachedCompositeAggregation.aggregateAfter(position);
}
SearchRequest searchRequest = new SearchRequest(sourceConfig.getIndex());
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.aggregation(buildAggregation(position, pageSize));
sourceBuilder.size(0);
sourceBuilder.query(queryBuilder);
searchRequest.source(sourceBuilder);
return searchRequest;
}
public AggregationBuilder buildAggregation(Map<String, Object> position, int pageSize) {
cachedCompositeAggregation.aggregateAfter(position);
cachedCompositeAggregation.size(pageSize); cachedCompositeAggregation.size(pageSize);
return cachedSearchRequest; return cachedCompositeAggregation;
} }
public Stream<Map<String, Object>> extractResults(CompositeAggregation agg, public Stream<Map<String, Object>> extractResults(CompositeAggregation agg,
@ -113,10 +121,10 @@ public class Pivot {
dataFrameIndexerTransformStats); dataFrameIndexerTransformStats);
} }
private void runTestQuery(Client client, final ActionListener<Boolean> listener) { private void runTestQuery(Client client, SourceConfig sourceConfig, final ActionListener<Boolean> listener) {
// no after key SearchRequest searchRequest = buildSearchRequest(sourceConfig, null, TEST_QUERY_PAGE_SIZE);
cachedCompositeAggregation.aggregateAfter(null);
client.execute(SearchAction.INSTANCE, cachedSearchRequest, ActionListener.wrap(response -> { client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(response -> {
if (response == null) { if (response == null) {
listener.onFailure(new RuntimeException("Unexpected null response from test query")); listener.onFailure(new RuntimeException("Unexpected null response from test query"));
return; return;
@ -131,16 +139,6 @@ public class Pivot {
})); }));
} }
private static SearchRequest createSearchRequest(String[] index, QueryBuilder query, CompositeAggregationBuilder compositeAggregation) {
SearchRequest searchRequest = new SearchRequest(index);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.aggregation(compositeAggregation);
sourceBuilder.size(0);
sourceBuilder.query(query);
searchRequest.source(sourceBuilder);
return searchRequest;
}
private static CompositeAggregationBuilder createCompositeAggregation(PivotConfig config) { private static CompositeAggregationBuilder createCompositeAggregation(PivotConfig config) {
CompositeAggregationBuilder compositeAggregation; CompositeAggregationBuilder compositeAggregation;

View File

@ -22,12 +22,13 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.xpack.core.dataframe.transforms.QueryConfig;
import org.elasticsearch.xpack.core.dataframe.transforms.SourceConfig;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.AggregationConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.AggregationConfig;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfigTests; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfigTests;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@ -83,42 +84,46 @@ public class PivotTests extends ESTestCase {
} }
public void testValidateExistingIndex() throws Exception { public void testValidateExistingIndex() throws Exception {
Pivot pivot = new Pivot(new String[]{"existing_source_index"}, new MatchAllQueryBuilder(), getValidPivotConfig()); SourceConfig source = new SourceConfig(new String[]{"existing_source_index"}, QueryConfig.matchAll());
Pivot pivot = new Pivot(getValidPivotConfig());
assertValidTransform(client, pivot); assertValidTransform(client, source, pivot);
} }
public void testValidateNonExistingIndex() throws Exception { public void testValidateNonExistingIndex() throws Exception {
Pivot pivot = new Pivot(new String[]{"non_existing_source_index"}, new MatchAllQueryBuilder(), getValidPivotConfig()); SourceConfig source = new SourceConfig(new String[]{"non_existing_source_index"}, QueryConfig.matchAll());
Pivot pivot = new Pivot(getValidPivotConfig());
assertInvalidTransform(client, pivot); assertInvalidTransform(client, source, pivot);
} }
public void testSearchFailure() throws Exception { public void testSearchFailure() throws Exception {
// test a failure during the search operation, transform creation fails if // test a failure during the search operation, transform creation fails if
// search has failures although they might just be temporary // search has failures although they might just be temporary
Pivot pivot = new Pivot(new String[]{"existing_source_index_with_failing_shards"}, SourceConfig source = new SourceConfig(new String[] { "existing_source_index_with_failing_shards" }, QueryConfig.matchAll());
new MatchAllQueryBuilder(),
getValidPivotConfig());
assertInvalidTransform(client, pivot); Pivot pivot = new Pivot(getValidPivotConfig());
assertInvalidTransform(client, source, pivot);
} }
public void testValidateAllSupportedAggregations() throws Exception { public void testValidateAllSupportedAggregations() throws Exception {
for (String agg : supportedAggregations) { for (String agg : supportedAggregations) {
AggregationConfig aggregationConfig = getAggregationConfig(agg); AggregationConfig aggregationConfig = getAggregationConfig(agg);
SourceConfig source = new SourceConfig(new String[]{"existing_source"}, QueryConfig.matchAll());
Pivot pivot = new Pivot(new String[]{"existing_source"}, new MatchAllQueryBuilder(), getValidPivotConfig(aggregationConfig)); Pivot pivot = new Pivot(getValidPivotConfig(aggregationConfig));
assertValidTransform(client, pivot); assertValidTransform(client, source, pivot);
} }
} }
public void testValidateAllUnsupportedAggregations() throws Exception { public void testValidateAllUnsupportedAggregations() throws Exception {
for (String agg : unsupportedAggregations) { for (String agg : unsupportedAggregations) {
AggregationConfig aggregationConfig = getAggregationConfig(agg); AggregationConfig aggregationConfig = getAggregationConfig(agg);
SourceConfig source = new SourceConfig(new String[]{"existing_source"}, QueryConfig.matchAll());
Pivot pivot = new Pivot(new String[]{"existing_source"}, new MatchAllQueryBuilder(), getValidPivotConfig(aggregationConfig)); Pivot pivot = new Pivot(getValidPivotConfig(aggregationConfig));
assertInvalidTransform(client, pivot); assertInvalidTransform(client, source, pivot);
} }
} }
@ -202,18 +207,18 @@ public class PivotTests extends ESTestCase {
return AggregationConfig.fromXContent(parser, false); return AggregationConfig.fromXContent(parser, false);
} }
private static void assertValidTransform(Client client, Pivot pivot) throws Exception { private static void assertValidTransform(Client client, SourceConfig source, Pivot pivot) throws Exception {
validate(client, pivot, true); validate(client, source, pivot, true);
} }
private static void assertInvalidTransform(Client client, Pivot pivot) throws Exception { private static void assertInvalidTransform(Client client, SourceConfig source, Pivot pivot) throws Exception {
validate(client, pivot, false); validate(client, source, pivot, false);
} }
private static void validate(Client client, Pivot pivot, boolean expectValid) throws Exception { private static void validate(Client client, SourceConfig source, Pivot pivot, boolean expectValid) throws Exception {
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<Exception> exceptionHolder = new AtomicReference<>(); final AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
pivot.validate(client, ActionListener.wrap(validity -> { pivot.validate(client, source, ActionListener.wrap(validity -> {
assertEquals(expectValid, validity); assertEquals(expectValid, validity);
latch.countDown(); latch.countDown();
}, e -> { }, e -> {