* [ML] Allowing destination index mappings to have dynamic types, adds scripted_metric agg * Making dynamic|source mapping explicit
This commit is contained in:
parent
4452e8e10f
commit
a8dbb07546
|
@ -314,6 +314,59 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
|
|||
assertThat(actual, containsString("2017-01-15T"));
|
||||
}
|
||||
|
||||
public void testPivotWithScriptedMetricAgg() throws Exception {
|
||||
String transformId = "scriptedMetricPivot";
|
||||
String dataFrameIndex = "scripted_metric_pivot_reviews";
|
||||
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
|
||||
|
||||
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
|
||||
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
|
||||
|
||||
String config = "{"
|
||||
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
|
||||
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";
|
||||
|
||||
config += " \"pivot\": {"
|
||||
+ " \"group_by\": {"
|
||||
+ " \"reviewer\": {"
|
||||
+ " \"terms\": {"
|
||||
+ " \"field\": \"user_id\""
|
||||
+ " } } },"
|
||||
+ " \"aggregations\": {"
|
||||
+ " \"avg_rating\": {"
|
||||
+ " \"avg\": {"
|
||||
+ " \"field\": \"stars\""
|
||||
+ " } },"
|
||||
+ " \"squared_sum\": {"
|
||||
+ " \"scripted_metric\": {"
|
||||
+ " \"init_script\": \"state.reviews_sqrd = []\","
|
||||
+ " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\","
|
||||
+ " \"combine_script\": \"state.reviews_sqrd\","
|
||||
+ " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\""
|
||||
+ " } }"
|
||||
+ " } }"
|
||||
+ "}";
|
||||
|
||||
createDataframeTransformRequest.setJsonEntity(config);
|
||||
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
|
||||
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
|
||||
assertTrue(indexExists(dataFrameIndex));
|
||||
|
||||
startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
|
||||
|
||||
// we expect 27 documents as there shall be 27 user_id's
|
||||
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
|
||||
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
|
||||
|
||||
// get and check some users
|
||||
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
|
||||
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
|
||||
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
|
||||
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
|
||||
actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.squared_sum", searchResult)).get(0);
|
||||
assertEquals(711.0, actual.doubleValue(), 0.000001);
|
||||
}
|
||||
|
||||
private void assertOnePivotValue(String query, double expected) throws IOException {
|
||||
Map<String, Object> searchResult = getAsMap(query);
|
||||
|
||||
|
|
|
@ -95,6 +95,5 @@ public class TransportPreviewDataFrameTransformAction extends
|
|||
},
|
||||
listener::onFailure
|
||||
));
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
|
||||
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
|
||||
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
|
||||
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
|
||||
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
|
||||
|
@ -73,6 +74,8 @@ final class AggregationResultUtils {
|
|||
} else {
|
||||
document.put(aggName, aggResultSingleValue.getValueAsString());
|
||||
}
|
||||
} else if (aggResult instanceof ScriptedMetric) {
|
||||
document.put(aggName, ((ScriptedMetric) aggResult).aggregation());
|
||||
} else {
|
||||
// Execution should never reach this point!
|
||||
// Creating transforms with unsupported aggregations shall not be possible
|
||||
|
|
|
@ -12,6 +12,11 @@ import java.util.stream.Collectors;
|
|||
import java.util.stream.Stream;
|
||||
|
||||
public final class Aggregations {
|
||||
|
||||
// the field mapping should not be set explicitly, so that ES determines the mapping dynamically from the data.
|
||||
private static final String DYNAMIC = "_dynamic";
|
||||
// the field mapping should be determined explicitly from the source field mapping if possible.
|
||||
private static final String SOURCE = "_source";
|
||||
private Aggregations() {}
|
||||
|
||||
/**
|
||||
|
@ -27,9 +32,10 @@ public final class Aggregations {
|
|||
AVG("avg", "double"),
|
||||
CARDINALITY("cardinality", "long"),
|
||||
VALUE_COUNT("value_count", "long"),
|
||||
MAX("max", null),
|
||||
MIN("min", null),
|
||||
SUM("sum", null);
|
||||
MAX("max", SOURCE),
|
||||
MIN("min", SOURCE),
|
||||
SUM("sum", SOURCE),
|
||||
SCRIPTED_METRIC("scripted_metric", DYNAMIC);
|
||||
|
||||
private final String aggregationType;
|
||||
private final String targetMapping;
|
||||
|
@ -55,8 +61,12 @@ public final class Aggregations {
|
|||
return aggregationSupported.contains(aggregationType.toUpperCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
public static boolean isDynamicMapping(String targetMapping) {
|
||||
return DYNAMIC.equals(targetMapping);
|
||||
}
|
||||
|
||||
public static String resolveTargetMapping(String aggregationType, String sourceType) {
|
||||
AggregationType agg = AggregationType.valueOf(aggregationType.toUpperCase(Locale.ROOT));
|
||||
return agg.getTargetMapping() == null ? sourceType : agg.getTargetMapping();
|
||||
return agg.getTargetMapping().equals(SOURCE) ? sourceType : agg.getTargetMapping();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsRespon
|
|||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
|
||||
|
@ -75,6 +76,8 @@ public final class SchemaUtil {
|
|||
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
|
||||
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
|
||||
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
|
||||
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
|
||||
aggregationTypes.put(agg.getName(), agg.getType());
|
||||
} else {
|
||||
// execution should not reach this point
|
||||
listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]"));
|
||||
|
@ -127,15 +130,17 @@ public final class SchemaUtil {
|
|||
|
||||
aggregationTypes.forEach((targetFieldName, aggregationName) -> {
|
||||
String sourceFieldName = aggregationSourceFieldNames.get(targetFieldName);
|
||||
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMappings.get(sourceFieldName));
|
||||
String sourceMapping = sourceFieldName == null ? null : sourceMappings.get(sourceFieldName);
|
||||
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMapping);
|
||||
|
||||
logger.debug(
|
||||
"Deduced mapping for: [" + targetFieldName + "], agg type [" + aggregationName + "] to [" + destinationMapping + "]");
|
||||
if (destinationMapping != null) {
|
||||
if (Aggregations.isDynamicMapping(destinationMapping)) {
|
||||
logger.info("Dynamic target mapping set for field ["+ targetFieldName +"] and aggregation [" + aggregationName +"]");
|
||||
} else if (destinationMapping != null) {
|
||||
targetMapping.put(targetFieldName, destinationMapping);
|
||||
} else {
|
||||
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double.");
|
||||
targetMapping.put(targetFieldName, "double");
|
||||
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping.");
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
@ -35,9 +35,11 @@ import org.elasticsearch.search.aggregations.metrics.ParsedCardinality;
|
|||
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
|
||||
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
|
||||
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
|
||||
import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
|
||||
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
|
||||
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
|
||||
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
|
||||
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder;
|
||||
|
@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
|
|||
map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c));
|
||||
map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c));
|
||||
map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c));
|
||||
map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c));
|
||||
map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c));
|
||||
map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c));
|
||||
map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c));
|
||||
|
@ -409,6 +412,92 @@ public class AggregationResultUtilsTests extends ESTestCase {
|
|||
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
|
||||
}
|
||||
|
||||
public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException {
|
||||
String targetField = randomAlphaOfLengthBetween(5, 10);
|
||||
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";
|
||||
|
||||
GroupConfig groupBy = parseGroupConfig("{"
|
||||
+ "\"" + targetField + "\" : {"
|
||||
+ " \"terms\" : {"
|
||||
+ " \"field\" : \"doesn't_matter_for_this_test\""
|
||||
+ " } },"
|
||||
+ "\"" + targetField2 + "\" : {"
|
||||
+ " \"terms\" : {"
|
||||
+ " \"field\" : \"doesn't_matter_for_this_test\""
|
||||
+ " } }"
|
||||
+ "}");
|
||||
|
||||
String aggName = randomAlphaOfLengthBetween(5, 10);
|
||||
String aggTypedName = "scripted_metric#" + aggName;
|
||||
|
||||
Collection<AggregationBuilder> aggregationBuilders = asList(AggregationBuilders.scriptedMetric(aggName));
|
||||
|
||||
Map<String, Object> input = asMap(
|
||||
"buckets",
|
||||
asList(
|
||||
asMap(
|
||||
KEY, asMap(
|
||||
targetField, "ID1",
|
||||
targetField2, "ID1_2"
|
||||
),
|
||||
aggTypedName, asMap(
|
||||
"value", asMap("field", 123.0)),
|
||||
DOC_COUNT, 1),
|
||||
asMap(
|
||||
KEY, asMap(
|
||||
targetField, "ID1",
|
||||
targetField2, "ID2_2"
|
||||
),
|
||||
aggTypedName, asMap(
|
||||
"value", asMap("field", 1.0)),
|
||||
DOC_COUNT, 2),
|
||||
asMap(
|
||||
KEY, asMap(
|
||||
targetField, "ID2",
|
||||
targetField2, "ID1_2"
|
||||
),
|
||||
aggTypedName, asMap(
|
||||
"value", asMap("field", 2.13)),
|
||||
DOC_COUNT, 3),
|
||||
asMap(
|
||||
KEY, asMap(
|
||||
targetField, "ID3",
|
||||
targetField2, "ID2_2"
|
||||
),
|
||||
aggTypedName, asMap(
|
||||
"value", asMap("field", 12.0)),
|
||||
DOC_COUNT, 4)
|
||||
));
|
||||
|
||||
List<Map<String, Object>> expected = asList(
|
||||
asMap(
|
||||
targetField, "ID1",
|
||||
targetField2, "ID1_2",
|
||||
aggName, asMap("field", 123.0)
|
||||
),
|
||||
asMap(
|
||||
targetField, "ID1",
|
||||
targetField2, "ID2_2",
|
||||
aggName, asMap("field", 1.0)
|
||||
),
|
||||
asMap(
|
||||
targetField, "ID2",
|
||||
targetField2, "ID1_2",
|
||||
aggName, asMap("field", 2.13)
|
||||
),
|
||||
asMap(
|
||||
targetField, "ID3",
|
||||
targetField2, "ID2_2",
|
||||
aggName, asMap("field", 12.0)
|
||||
)
|
||||
);
|
||||
Map<String, String> fieldTypeMap = asStringMap(
|
||||
targetField, "keyword",
|
||||
targetField2, "keyword"
|
||||
);
|
||||
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
|
||||
}
|
||||
|
||||
public void testExtractCompositeAggregationResultsDocIDs() throws IOException {
|
||||
String targetField = randomAlphaOfLengthBetween(5, 10);
|
||||
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";
|
||||
|
|
|
@ -15,9 +15,31 @@ public class AggregationsTests extends ESTestCase {
|
|||
assertEquals("double", Aggregations.resolveTargetMapping("avg", "int"));
|
||||
assertEquals("double", Aggregations.resolveTargetMapping("avg", "double"));
|
||||
|
||||
// cardinality
|
||||
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "int"));
|
||||
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "double"));
|
||||
|
||||
// value_count
|
||||
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "int"));
|
||||
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "double"));
|
||||
|
||||
// max
|
||||
assertEquals("int", Aggregations.resolveTargetMapping("max", "int"));
|
||||
assertEquals("double", Aggregations.resolveTargetMapping("max", "double"));
|
||||
assertEquals("half_float", Aggregations.resolveTargetMapping("max", "half_float"));
|
||||
|
||||
// min
|
||||
assertEquals("int", Aggregations.resolveTargetMapping("min", "int"));
|
||||
assertEquals("double", Aggregations.resolveTargetMapping("min", "double"));
|
||||
assertEquals("half_float", Aggregations.resolveTargetMapping("min", "half_float"));
|
||||
|
||||
// sum
|
||||
assertEquals("int", Aggregations.resolveTargetMapping("sum", "int"));
|
||||
assertEquals("double", Aggregations.resolveTargetMapping("sum", "double"));
|
||||
assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float"));
|
||||
|
||||
// scripted_metric
|
||||
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
|
||||
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,9 +37,7 @@ import org.junit.Before;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
@ -176,14 +174,20 @@ public class PivotTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private AggregationConfig getAggregationConfig(String agg) throws IOException {
|
||||
if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) {
|
||||
return parseAggregations("{\"pivot_scripted_metric\": {\n" +
|
||||
"\"scripted_metric\": {\n" +
|
||||
" \"init_script\" : \"state.transactions = []\",\n" +
|
||||
" \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" +
|
||||
" \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" +
|
||||
" \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" +
|
||||
" }\n" +
|
||||
"}}");
|
||||
}
|
||||
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
|
||||
+ " }\n" + " }" + "}");
|
||||
}
|
||||
|
||||
private Map<String, String> getFieldMappings() {
|
||||
return Collections.singletonMap("values", "double");
|
||||
}
|
||||
|
||||
private AggregationConfig parseAggregations(String json) throws IOException {
|
||||
final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);
|
||||
|
|
Loading…
Reference in New Issue