[ML] [Data Frame] nesting group_by fields like other aggs (#42718) (#42760)

This commit is contained in:
Benjamin Trent 2019-05-31 10:55:35 -05:00 committed by GitHub
parent 0a37dd7a86
commit f22dcfb9da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 5 deletions

View File

@ -251,10 +251,10 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
config += " \"pivot\": {"
+ " \"group_by\": {"
+ " \"reviewer\": {\"terms\": { \"field\": \"user_id\" }},"
+ " \"user.id\": {\"terms\": { \"field\": \"user_id\" }},"
+ " \"by_day\": {\"date_histogram\": {\"fixed_interval\": \"1d\",\"field\":\"timestamp\",\"format\":\"yyyy-MM-dd\"}}},"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"user.avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } } } }"
@ -265,10 +265,14 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
List<Map<String, Object>> preview = (List<Map<String, Object>>)previewDataframeResponse.get("preview");
// preview is limited to 100
assertThat(preview.size(), equalTo(100));
Set<String> expectedFields = new HashSet<>(Arrays.asList("reviewer", "by_day", "avg_rating"));
Set<String> expectedTopLevelFields = new HashSet<>(Arrays.asList("user", "by_day"));
Set<String> expectedNestedFields = new HashSet<>(Arrays.asList("id", "avg_rating"));
preview.forEach(p -> {
Set<String> keys = p.keySet();
assertThat(keys, equalTo(expectedFields));
assertThat(keys, equalTo(expectedTopLevelFields));
Map<String, Object> nestedObj = (Map<String, Object>)p.get("user");
keys = nestedObj.keySet();
assertThat(keys, equalTo(expectedNestedFields));
});
}

View File

@ -61,7 +61,7 @@ public final class AggregationResultUtils {
groups.getGroups().keySet().forEach(destinationFieldName -> {
Object value = bucket.getKey().get(destinationFieldName);
idGen.add(destinationFieldName, value);
document.put(destinationFieldName, value);
updateDocument(document, destinationFieldName, value);
});
List<String> aggNames = aggregationBuilders.stream().map(AggregationBuilder::getName).collect(Collectors.toList());