[ML][Data Frame] add support for bucket_selector (#44718) (#45008)

This commit is contained in:
Benjamin Trent 2019-07-30 11:32:58 -05:00 committed by GitHub
parent 548c767b6b
commit 22feedf289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 73 additions and 6 deletions

View File

@@ -130,6 +130,49 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
assertThat(actual, equalTo(pipelineValue));
}
// Integration test: a pivot transform whose aggregations include a bucket_selector
// pipeline agg. The selector filters buckets rather than producing a value, so the
// dest index should contain only the reviewers whose average rating passed the
// selector's script (params.rating > 3.8).
public void testBucketSelectorPivot() throws Exception {
String transformId = "simple_bucket_selector_pivot";
String dataFrameIndex = "bucket_selector_idx";
// Grant the restricted test role access to both the source reviews index and the dest index.
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
// Transform config: group by user_id ("reviewer"), average the "stars" field,
// then drop every bucket whose average rating is <= 3.8 via bucket_selector.
String config = "{"
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},"
+ " \"frequency\": \"1s\","
+ " \"pivot\": {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } },"
+ " \"over_38\": {"
+ " \"bucket_selector\" : {"
+ " \"buckets_path\": {\"rating\":\"avg_rating\"}, "
+ " \"script\": \"params.rating > 3.8\""
+ " }"
+ " } } }"
+ "}";
createDataframeTransformRequest.setJsonEntity(config);
// PUT the transform and check the API acknowledged it.
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
// Run the transform to completion and make sure it actually wrote the dest index.
startAndWaitForTransform(transformId, dataFrameIndex);
assertTrue(indexExists(dataFrameIndex));
// get and check some users
// Both of these users have an average rating above 3.8, so they must survive the selector.
assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_11", 3.846153846);
assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_26", 3.918918918);
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
// Should be less than the total number of users since we filtered every user who had an average review less than or equal to 3.8
assertEquals(21, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
}
public void testContinuousPivot() throws Exception {
String indexName = "continuous_reviews";
createReviewsIndex(indexName);

View File

@@ -201,6 +201,7 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer&lt;DataFrameInd
newPosition,
agg.getBuckets().isEmpty());
// NOTE: progress is also mutated in ClientDataFrameIndexer#onFinished
if (progress != null) {
progress.docsProcessed(getStats().getNumDocuments() - docsBeforeProcess);
}

View File

@@ -746,6 +746,16 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
nextCheckpoint = null;
// Reset our failure count as we have finished and may start again with a new checkpoint
failureCount.set(0);
// TODO: progress hack to get around bucket_selector filtering out buckets
// With bucket_selector we could have read all the buckets and completed the transform
// but not "see" all the buckets since they were filtered out. Consequently, progress would
// show less than 100% even though we are done.
// NOTE: this method is called in the same thread as the processing thread.
// Theoretically, there should not be a race condition with updating progress here.
if (progress != null && progress.getRemainingDocs() > 0) {
progress.docsProcessed(progress.getRemainingDocs());
}
if (shouldAuditOnFinish(checkpoint)) {
auditor.info(transformTask.getTransformId(),
"Finished indexing for data frame transform checkpoint [" + checkpoint + "].");

View File

@@ -75,13 +75,15 @@ public final class AggregationResultUtils {
aggNames.addAll(pipelineAggs.stream().map(PipelineAggregationBuilder::getName).collect(Collectors.toList()));
for (String aggName: aggNames) {
final String fieldType = fieldTypeMap.get(aggName);
// TODO: support other aggregation types
Aggregation aggResult = bucket.getAggregations().get(aggName);
AggValueExtractor extractor = getExtractor(aggResult);
updateDocument(document, aggName, extractor.value(aggResult, fieldType));
// This indicates not that the value contained in the `aggResult` is null, but that the `aggResult` is not
// present at all in the `bucket.getAggregations`. This could occur in the case of a `bucket_selector` agg, which
// does not calculate a value, but instead manipulates other results.
if (aggResult != null) {
final String fieldType = fieldTypeMap.get(aggName);
AggValueExtractor extractor = getExtractor(aggResult);
updateDocument(document, aggName, extractor.value(aggResult, fieldType));
}
}
document.put(DataFrameField.DOCUMENT_ID_FIELD, idGen.getID());

View File

@@ -38,6 +38,7 @@ public final class Aggregations {
GEO_CENTROID("geo_centroid", "geo_point"),
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
WEIGHTED_AVG("weighted_avg", DYNAMIC),
BUCKET_SELECTOR("bucket_selector", DYNAMIC),
BUCKET_SCRIPT("bucket_script", DYNAMIC);
private final String aggregationType;

View File

@@ -50,6 +50,10 @@ public class AggregationsTests extends ESTestCase {
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", "int"));
// bucket_selector
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_selector", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_selector", "int"));
// weighted_avg
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", "double"));

View File

@@ -217,6 +217,12 @@ public class PivotTests extends ESTestCase {
"\"buckets_path\":{\"param_1\":\"other_bucket\"}," +
"\"script\":\"return params.param_1\"}}}");
}
if (agg.equals(AggregationType.BUCKET_SELECTOR.getName())) {
return parseAggregations("{\"pivot_bucket_selector\":{" +
"\"bucket_selector\":{" +
"\"buckets_path\":{\"param_1\":\"other_bucket\"}," +
"\"script\":\"params.param_1 > 42.0\"}}}");
}
if (agg.equals(AggregationType.WEIGHTED_AVG.getName())) {
return parseAggregations("{\n" +
"\"pivot_weighted_avg\": {\n" +