[ML] Changes default destination index field mapping and adds scripted_metric agg (#40750) (#40846)

* [ML] Allowing destination index mappings to have dynamic types, adds scripted_metric agg

* Making dynamic|source mapping explicit
This commit is contained in:
Benjamin Trent 2019-04-05 11:34:20 -05:00 committed by GitHub
parent 4452e8e10f
commit a8dbb07546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 200 additions and 15 deletions

View File

@ -314,6 +314,59 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
assertThat(actual, containsString("2017-01-15T"));
}
/**
 * Creates a pivot transform that pairs an {@code avg} aggregation with a {@code scripted_metric}
 * aggregation, starts it, and checks the values stored in the destination index for one reviewer.
 */
public void testPivotWithScriptedMetricAgg() throws Exception {
    final String transformId = "scriptedMetricPivot";
    final String destIndex = "scripted_metric_pivot_reviews";
    setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, destIndex);

    final Request putTransform = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
        BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

    // Source and destination indices of the transform.
    final String endpoints = "{"
        + " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
        + " \"dest\": {\"index\":\"" + destIndex + "\"},";
    // Group by reviewer; compute the average rating and a scripted sum of squared ratings.
    final String pivot = " \"pivot\": {"
        + " \"group_by\": {"
        + " \"reviewer\": {"
        + " \"terms\": {"
        + " \"field\": \"user_id\""
        + " } } },"
        + " \"aggregations\": {"
        + " \"avg_rating\": {"
        + " \"avg\": {"
        + " \"field\": \"stars\""
        + " } },"
        + " \"squared_sum\": {"
        + " \"scripted_metric\": {"
        + " \"init_script\": \"state.reviews_sqrd = []\","
        + " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\","
        + " \"combine_script\": \"state.reviews_sqrd\","
        + " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\""
        + " } }"
        + " } }"
        + "}";
    putTransform.setJsonEntity(endpoints + pivot);

    final Map<String, Object> putResponse = entityAsMap(client().performRequest(putTransform));
    assertThat(putResponse.get("acknowledged"), equalTo(Boolean.TRUE));
    assertTrue(indexExists(destIndex));

    startAndWaitForTransform(transformId, destIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

    // we expect 27 documents as there shall be 27 user_id's
    final Map<String, Object> stats = getAsMap(destIndex + "/_stats");
    assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", stats));

    // get and check some users
    final Map<String, Object> user4Hits = getAsMap(destIndex + "/_search?q=reviewer:user_4");
    assertEquals(1, XContentMapValues.extractValue("hits.total.value", user4Hits));

    final Number avgRating =
        (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", user4Hits)).get(0);
    assertEquals(3.878048780, avgRating.doubleValue(), 0.000001);

    final Number squaredSum =
        (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.squared_sum", user4Hits)).get(0);
    assertEquals(711.0, squaredSum.doubleValue(), 0.000001);
}
private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

View File

@ -95,6 +95,5 @@ public class TransportPreviewDataFrameTransformAction extends
},
listener::onFailure
));
}
}

View File

@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
@ -73,6 +74,8 @@ final class AggregationResultUtils {
} else {
document.put(aggName, aggResultSingleValue.getValueAsString());
}
} else if (aggResult instanceof ScriptedMetric) {
document.put(aggName, ((ScriptedMetric) aggResult).aggregation());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible

View File

@ -12,6 +12,11 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
public final class Aggregations {
// the field mapping should not explicitly be set and allow ES to dynamically determine mapping via the data.
private static final String DYNAMIC = "_dynamic";
// the field mapping should be determined explicitly from the source field mapping if possible.
private static final String SOURCE = "_source";
private Aggregations() {}
/**
@ -27,9 +32,10 @@ public final class Aggregations {
AVG("avg", "double"),
CARDINALITY("cardinality", "long"),
VALUE_COUNT("value_count", "long"),
MAX("max", null),
MIN("min", null),
SUM("sum", null);
MAX("max", SOURCE),
MIN("min", SOURCE),
SUM("sum", SOURCE),
SCRIPTED_METRIC("scripted_metric", DYNAMIC);
private final String aggregationType;
private final String targetMapping;
@ -55,8 +61,12 @@ public final class Aggregations {
return aggregationSupported.contains(aggregationType.toUpperCase(Locale.ROOT));
}
/**
 * Tells whether a resolved target mapping is the dynamic-mapping marker.
 *
 * @param targetMapping the mapping value to check; may be {@code null}
 * @return {@code true} only when {@code targetMapping} equals the {@code DYNAMIC} constant
 */
public static boolean isDynamicMapping(String targetMapping) {
    if (targetMapping == null) {
        return false;
    }
    return targetMapping.equals(DYNAMIC);
}
/**
 * Resolves the destination field mapping for an aggregation.
 *
 * @param aggregationType aggregation name; it is upper-cased with {@link Locale#ROOT} and looked up
 *                        through {@code AggregationType.valueOf}
 * @param sourceType      mapping type of the source field, returned when the aggregation's target
 *                        mapping defers to the source
 * @return {@code sourceType} when the aggregation defers to the source mapping, otherwise the
 *         aggregation's own target mapping
 */
public static String resolveTargetMapping(String aggregationType, String sourceType) {
AggregationType agg = AggregationType.valueOf(aggregationType.toUpperCase(Locale.ROOT));
// NOTE(review): two return statements follow back to back. This looks like the removed and the
// added line of a diff hunk shown together -- confirm which one is live before relying on it.
// Variant that treats a null target mapping as "use the source type".
return agg.getTargetMapping() == null ? sourceType : agg.getTargetMapping();
// Variant that treats the SOURCE constant as "use the source type".
return agg.getTargetMapping().equals(SOURCE) ? sourceType : agg.getTargetMapping();
}
}

View File

@ -15,6 +15,7 @@ import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsRespon
import org.elasticsearch.client.Client;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@ -75,6 +76,8 @@ public final class SchemaUtil {
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
aggregationTypes.put(agg.getName(), agg.getType());
} else {
// execution should not reach this point
listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]"));
@ -127,15 +130,17 @@ public final class SchemaUtil {
aggregationTypes.forEach((targetFieldName, aggregationName) -> {
String sourceFieldName = aggregationSourceFieldNames.get(targetFieldName);
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMappings.get(sourceFieldName));
String sourceMapping = sourceFieldName == null ? null : sourceMappings.get(sourceFieldName);
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMapping);
logger.debug(
"Deduced mapping for: [" + targetFieldName + "], agg type [" + aggregationName + "] to [" + destinationMapping + "]");
if (destinationMapping != null) {
if (Aggregations.isDynamicMapping(destinationMapping)) {
logger.info("Dynamic target mapping set for field ["+ targetFieldName +"] and aggregation [" + aggregationName +"]");
} else if (destinationMapping != null) {
targetMapping.put(targetFieldName, destinationMapping);
} else {
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double.");
targetMapping.put(targetFieldName, "double");
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping.");
}
});

View File

@ -35,9 +35,11 @@ import org.elasticsearch.search.aggregations.metrics.ParsedCardinality;
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder;
@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c));
map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c));
map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c));
map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c));
map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c));
map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c));
map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c));
@ -409,6 +412,92 @@ public class AggregationResultUtilsTests extends ESTestCase {
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}
/**
 * Feeds composite buckets that carry a {@code scripted_metric} result through the extraction and
 * expects each resulting document to hold the script's object value under the aggregation name.
 */
public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException {
    final String groupField = randomAlphaOfLengthBetween(5, 10);
    final String groupField2 = randomAlphaOfLengthBetween(5, 10) + "_2";

    // Two terms group-bys; the source field they point at plays no role in this test.
    final String groupJson = "{"
        + "\"" + groupField + "\" : {"
        + " \"terms\" : {"
        + " \"field\" : \"doesn't_matter_for_this_test\""
        + " } },"
        + "\"" + groupField2 + "\" : {"
        + " \"terms\" : {"
        + " \"field\" : \"doesn't_matter_for_this_test\""
        + " } }"
        + "}";
    final GroupConfig groupBy = parseGroupConfig(groupJson);

    final String metricName = randomAlphaOfLengthBetween(5, 10);
    final String typedMetricName = "scripted_metric#" + metricName;
    final Collection<AggregationBuilder> aggs = asList(AggregationBuilders.scriptedMetric(metricName));

    // Composite response: four buckets, each with an object-valued scripted metric result.
    final Map<String, Object> compositeResponse = asMap(
        "buckets",
        asList(
            asMap(
                KEY, asMap(groupField, "ID1", groupField2, "ID1_2"),
                typedMetricName, asMap("value", asMap("field", 123.0)),
                DOC_COUNT, 1),
            asMap(
                KEY, asMap(groupField, "ID1", groupField2, "ID2_2"),
                typedMetricName, asMap("value", asMap("field", 1.0)),
                DOC_COUNT, 2),
            asMap(
                KEY, asMap(groupField, "ID2", groupField2, "ID1_2"),
                typedMetricName, asMap("value", asMap("field", 2.13)),
                DOC_COUNT, 3),
            asMap(
                KEY, asMap(groupField, "ID3", groupField2, "ID2_2"),
                typedMetricName, asMap("value", asMap("field", 12.0)),
                DOC_COUNT, 4)
        ));

    // One flat document per bucket: both group keys plus the metric's object value.
    final List<Map<String, Object>> expectedDocs = asList(
        asMap(groupField, "ID1", groupField2, "ID1_2", metricName, asMap("field", 123.0)),
        asMap(groupField, "ID1", groupField2, "ID2_2", metricName, asMap("field", 1.0)),
        asMap(groupField, "ID2", groupField2, "ID1_2", metricName, asMap("field", 2.13)),
        asMap(groupField, "ID3", groupField2, "ID2_2", metricName, asMap("field", 12.0))
    );

    // Only the group-by fields get an explicit type; the metric field is left out of the map.
    final Map<String, String> groupFieldTypes = asStringMap(
        groupField, "keyword",
        groupField2, "keyword"
    );

    executeTest(groupBy, aggs, compositeResponse, groupFieldTypes, expectedDocs, 10);
}
public void testExtractCompositeAggregationResultsDocIDs() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";

View File

@ -15,9 +15,31 @@ public class AggregationsTests extends ESTestCase {
assertEquals("double", Aggregations.resolveTargetMapping("avg", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("avg", "double"));
// cardinality
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "int"));
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "double"));
// value_count
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "int"));
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "double"));
// max
assertEquals("int", Aggregations.resolveTargetMapping("max", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("max", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("max", "half_float"));
// min
assertEquals("int", Aggregations.resolveTargetMapping("min", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("min", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("min", "half_float"));
// sum
assertEquals("int", Aggregations.resolveTargetMapping("sum", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("sum", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float"));
// scripted_metric
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));
}
}

View File

@ -37,9 +37,7 @@ import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@ -176,14 +174,20 @@ public class PivotTests extends ESTestCase {
}
/**
 * Builds the aggregation config used by the pivot tests for the given aggregation name.
 * {@code scripted_metric} gets a fixed four-script definition; every other name gets a
 * single-field aggregation over the {@code values} field.
 *
 * @param agg aggregation name as used in the request JSON
 * @return the parsed aggregation config
 * @throws IOException if parsing the generated JSON fails
 */
private AggregationConfig getAggregationConfig(String agg) throws IOException {
    final String json;
    if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) {
        // scripted_metric has no "field" parameter, so it needs its own definition
        json = "{\"pivot_scripted_metric\": {\n" +
            "\"scripted_metric\": {\n" +
            " \"init_script\" : \"state.transactions = []\",\n" +
            " \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" +
            " \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" +
            " \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" +
            " }\n" +
            "}}";
    } else {
        json = "{\n"
            + " \"pivot_" + agg + "\": {\n"
            + " \"" + agg + "\": {\n"
            + " \"field\": \"values\"\n"
            + " }\n"
            + " }"
            + "}";
    }
    return parseAggregations(json);
}
/**
 * Supplies the field mappings used by the pivot tests: a single {@code values} field typed
 * {@code double}.
 *
 * @return an immutable single-entry map from field name to mapping type
 */
private Map<String, String> getFieldMappings() {
    final Map<String, String> mappings = Collections.singletonMap("values", "double");
    return mappings;
}
private AggregationConfig parseAggregations(String json) throws IOException {
final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);