From a8dbb07546034313d155618055cd8c8728bd5a6a Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 5 Apr 2019 11:34:20 -0500 Subject: [PATCH] [ML] Changes default destination index field mapping and adds scripted_metric agg (#40750) (#40846) * [ML] Allowing destination index mappings to have dynamic types, adds script_metric agg * Making dynamic|source mapping explicit --- .../integration/DataFramePivotRestIT.java | 53 +++++++++++ ...nsportPreviewDataFrameTransformAction.java | 1 - .../pivot/AggregationResultUtils.java | 3 + .../transforms/pivot/Aggregations.java | 18 +++- .../transforms/pivot/SchemaUtil.java | 13 ++- .../pivot/AggregationResultUtilsTests.java | 89 +++++++++++++++++++ .../transforms/pivot/AggregationsTests.java | 22 +++++ .../transforms/pivot/PivotTests.java | 16 ++-- 8 files changed, 200 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java index 95daf11f674..0d14851aa7c 100644 --- a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java +++ b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java @@ -314,6 +314,59 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase { assertThat(actual, containsString("2017-01-15T")); } + public void testPivotWithScriptedMetricAgg() throws Exception { + String transformId = "scriptedMetricPivot"; + String dataFrameIndex = "scripted_metric_pivot_reviews"; + setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex); + + final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId, + BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS); + + String config = "{" + + " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"}," + + " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},"; + + config += " \"pivot\": {" + + " \"group_by\": {" + + " \"reviewer\": {" + + " \"terms\": {" + + " \"field\": \"user_id\"" + + " } } }," + + " \"aggregations\": {" + + " \"avg_rating\": {" + + " \"avg\": {" + + " \"field\": \"stars\"" + + " } }," + + " \"squared_sum\": {" + + " \"scripted_metric\": {" + + " \"init_script\": \"state.reviews_sqrd = []\"," + + " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\"," + + " \"combine_script\": \"state.reviews_sqrd\"," + + " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\"" + + " } }" + + " } }" + + "}"; + + createDataframeTransformRequest.setJsonEntity(config); + Map createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest)); + assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE)); + assertTrue(indexExists(dataFrameIndex)); + + startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS); + + // we expect 27 documents as there shall be 27 user_id's + Map indexStats = getAsMap(dataFrameIndex + "/_stats"); + assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats)); + + // get and check some users + Map searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4"); + assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult)); + Number actual = (Number) ((List) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0); + assertEquals(3.878048780, actual.doubleValue(), 0.000001); + actual = (Number) ((List) XContentMapValues.extractValue("hits.hits._source.squared_sum", searchResult)).get(0); + assertEquals(711.0, actual.doubleValue(), 0.000001); + } + private void assertOnePivotValue(String query, double expected) throws IOException { Map searchResult = getAsMap(query); diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportPreviewDataFrameTransformAction.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportPreviewDataFrameTransformAction.java index 63b2ed720c0..2a4ba47f507 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportPreviewDataFrameTransformAction.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportPreviewDataFrameTransformAction.java @@ -95,6 +95,5 @@ public class TransportPreviewDataFrameTransformAction extends }, listener::onFailure )); - } } diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java index 5d77f82e610..574afd4f2fd 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java @@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue; +import org.elasticsearch.search.aggregations.metrics.ScriptedMetric; import org.elasticsearch.xpack.core.dataframe.DataFrameField; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig; @@ -73,6 +74,8 @@ final class AggregationResultUtils { } else { document.put(aggName, aggResultSingleValue.getValueAsString()); } + } else if (aggResult instanceof ScriptedMetric) { + document.put(aggName, ((ScriptedMetric) aggResult).aggregation()); } else { // Execution should never reach this point! // Creating transforms with unsupported aggregations shall not be possible diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java index 555deae3674..39b139314d4 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java @@ -12,6 +12,11 @@ import java.util.stream.Collectors; import java.util.stream.Stream; public final class Aggregations { + + // the field mapping should not explicitly be set and allow ES to dynamically determine mapping via the data. + private static final String DYNAMIC = "_dynamic"; + // the field mapping should be determined explicitly from the source field mapping if possible. + private static final String SOURCE = "_source"; private Aggregations() {} /** @@ -27,9 +32,10 @@ public final class Aggregations { AVG("avg", "double"), CARDINALITY("cardinality", "long"), VALUE_COUNT("value_count", "long"), - MAX("max", null), - MIN("min", null), - SUM("sum", null); + MAX("max", SOURCE), + MIN("min", SOURCE), + SUM("sum", SOURCE), + SCRIPTED_METRIC("scripted_metric", DYNAMIC); private final String aggregationType; private final String targetMapping; @@ -55,8 +61,12 @@ public final class Aggregations { return aggregationSupported.contains(aggregationType.toUpperCase(Locale.ROOT)); } + public static boolean isDynamicMapping(String targetMapping) { + return DYNAMIC.equals(targetMapping); + } + public static String resolveTargetMapping(String aggregationType, String sourceType) { AggregationType agg = AggregationType.valueOf(aggregationType.toUpperCase(Locale.ROOT)); - return agg.getTargetMapping() == null ? sourceType : agg.getTargetMapping(); + return agg.getTargetMapping().equals(SOURCE) ? sourceType : agg.getTargetMapping(); } } diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java index ff967213e81..deb4afdb73d 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsRespon import org.elasticsearch.client.Client; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; @@ -75,6 +76,8 @@ public final class SchemaUtil { ValuesSourceAggregationBuilder valueSourceAggregation = (ValuesSourceAggregationBuilder) agg; aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field()); aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType()); + } else if(agg instanceof ScriptedMetricAggregationBuilder) { + aggregationTypes.put(agg.getName(), agg.getType()); } else { // execution should not reach this point listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]")); @@ -127,15 +130,17 @@ public final class SchemaUtil { aggregationTypes.forEach((targetFieldName, aggregationName) -> { String sourceFieldName = aggregationSourceFieldNames.get(targetFieldName); - String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMappings.get(sourceFieldName)); + String sourceMapping = sourceFieldName == null ? null : sourceMappings.get(sourceFieldName); + String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMapping); logger.debug( "Deduced mapping for: [" + targetFieldName + "], agg type [" + aggregationName + "] to [" + destinationMapping + "]"); - if (destinationMapping != null) { + if (Aggregations.isDynamicMapping(destinationMapping)) { + logger.info("Dynamic target mapping set for field ["+ targetFieldName +"] and aggregation [" + aggregationName +"]"); + } else if (destinationMapping != null) { targetMapping.put(targetFieldName, destinationMapping); } else { - logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double."); - targetMapping.put(targetFieldName, "double"); + logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping."); } }); diff --git a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java index c2c22dc6ffa..62a4de353bc 100644 --- a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java +++ b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtilsTests.java @@ -35,9 +35,11 @@ import org.elasticsearch.search.aggregations.metrics.ParsedCardinality; import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats; import org.elasticsearch.search.aggregations.metrics.ParsedMax; import org.elasticsearch.search.aggregations.metrics.ParsedMin; +import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric; import org.elasticsearch.search.aggregations.metrics.ParsedStats; import org.elasticsearch.search.aggregations.metrics.ParsedSum; import org.elasticsearch.search.aggregations.metrics.ParsedValueCount; +import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder; @@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase { map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)); map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)); map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)); + map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c)); map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c)); map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c)); map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c)); @@ -409,6 +412,92 @@ public class AggregationResultUtilsTests extends ESTestCase { executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10); } + public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException { + String targetField = randomAlphaOfLengthBetween(5, 10); + String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2"; + + GroupConfig groupBy = parseGroupConfig("{" + + "\"" + targetField + "\" : {" + + " \"terms\" : {" + + " \"field\" : \"doesn't_matter_for_this_test\"" + + " } }," + + "\"" + targetField2 + "\" : {" + + " \"terms\" : {" + + " \"field\" : \"doesn't_matter_for_this_test\"" + + " } }" + + "}"); + + String aggName = randomAlphaOfLengthBetween(5, 10); + String aggTypedName = "scripted_metric#" + aggName; + + Collection aggregationBuilders = asList(AggregationBuilders.scriptedMetric(aggName)); + + Map input = asMap( + "buckets", + asList( + asMap( + KEY, asMap( + targetField, "ID1", + targetField2, "ID1_2" + ), + aggTypedName, asMap( + "value", asMap("field", 123.0)), + DOC_COUNT, 1), + asMap( + KEY, asMap( + targetField, "ID1", + targetField2, "ID2_2" + ), + aggTypedName, asMap( + "value", asMap("field", 1.0)), + DOC_COUNT, 2), + asMap( + KEY, asMap( + targetField, "ID2", + targetField2, "ID1_2" + ), + aggTypedName, asMap( + "value", asMap("field", 2.13)), + DOC_COUNT, 3), + asMap( + KEY, asMap( + targetField, "ID3", + targetField2, "ID2_2" + ), + aggTypedName, asMap( + "value", asMap("field", 12.0)), + DOC_COUNT, 4) + )); + + List> expected = asList( + asMap( + targetField, "ID1", + targetField2, "ID1_2", + aggName, asMap("field", 123.0) + ), + asMap( + targetField, "ID1", + targetField2, "ID2_2", + aggName, asMap("field", 1.0) + ), + asMap( + targetField, "ID2", + targetField2, "ID1_2", + aggName, asMap("field", 2.13) + ), + asMap( + targetField, "ID3", + targetField2, "ID2_2", + aggName, asMap("field", 12.0) + ) + ); + Map fieldTypeMap = asStringMap( + targetField, "keyword", + targetField2, "keyword" + ); + executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10); + } + public void testExtractCompositeAggregationResultsDocIDs() throws IOException { String targetField = randomAlphaOfLengthBetween(5, 10); String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2"; diff --git a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java index 23720ab6af3..47476baebdd 100644 --- a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java +++ b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java @@ -15,9 +15,31 @@ public class AggregationsTests extends ESTestCase { assertEquals("double", Aggregations.resolveTargetMapping("avg", "int")); assertEquals("double", Aggregations.resolveTargetMapping("avg", "double")); + // cardinality + assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "int")); + assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "double")); + + // value_count + assertEquals("long", Aggregations.resolveTargetMapping("value_count", "int")); + assertEquals("long", Aggregations.resolveTargetMapping("value_count", "double")); + // max assertEquals("int", Aggregations.resolveTargetMapping("max", "int")); assertEquals("double", Aggregations.resolveTargetMapping("max", "double")); assertEquals("half_float", Aggregations.resolveTargetMapping("max", "half_float")); + + // min + assertEquals("int", Aggregations.resolveTargetMapping("min", "int")); + assertEquals("double", Aggregations.resolveTargetMapping("min", "double")); + assertEquals("half_float", Aggregations.resolveTargetMapping("min", "half_float")); + + // sum + assertEquals("int", Aggregations.resolveTargetMapping("sum", "int")); + assertEquals("double", Aggregations.resolveTargetMapping("sum", "double")); + assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float")); + + // scripted_metric + assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null)); + assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int")); } } diff --git a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java index c39e9a2589f..be23f515ac8 100644 --- a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java +++ b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java @@ -37,9 +37,7 @@ import org.junit.Before; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -176,14 +174,20 @@ public class PivotTests extends ESTestCase { } private AggregationConfig getAggregationConfig(String agg) throws IOException { + if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) { + return parseAggregations("{\"pivot_scripted_metric\": {\n" + + "\"scripted_metric\": {\n" + + " \"init_script\" : \"state.transactions = []\",\n" + + " \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" + + " \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" + + " \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" + + " }\n" + + "}}"); + } return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n" + " }\n" + " }" + "}"); } - private Map getFieldMappings() { - return Collections.singletonMap("values", "double"); - } - private AggregationConfig parseAggregations(String json) throws IOException { final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);