[ML][Data Frame] allow null values for aggs with sparse data (#42966) (#42998)

* [ML][Data Frame] allow null values for aggs with sparse data

* Making classes static, memory allocation optimization
This commit is contained in:
Benjamin Trent 2019-06-07 15:43:06 -05:00 committed by GitHub
parent d6fe4b648d
commit 553c73b22d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 44 deletions

View File

@ -6,15 +6,13 @@
package org.elasticsearch.xpack.dataframe.transforms.pivot;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
@ -23,6 +21,7 @@ import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
import org.elasticsearch.xpack.dataframe.transforms.IDGenerator;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -32,7 +31,15 @@ import java.util.stream.Stream;
import static org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil.isNumericType;
public final class AggregationResultUtils {
private static final Logger logger = LogManager.getLogger(AggregationResultUtils.class);
/**
 * Immutable lookup table of value extractors, keyed by the fully qualified class name
 * of the aggregation result interface each extractor handles. The extractors are
 * created once here and shared, rather than being allocated for every aggregation result.
 */
private static final Map<String, AggValueExtractor> TYPE_VALUE_EXTRACTOR_MAP;
static {
    final Map<String, AggValueExtractor> extractorsByType = new HashMap<>();
    extractorsByType.put(GeoCentroid.class.getName(), new GeoCentroidAggExtractor());
    extractorsByType.put(ScriptedMetric.class.getName(), new ScriptedMetricAggExtractor());
    extractorsByType.put(SingleValue.class.getName(), new SingleValueAggExtractor());
    // Publish a read-only view so the registry cannot be altered after class initialization.
    TYPE_VALUE_EXTRACTOR_MAP = Collections.unmodifiableMap(extractorsByType);
}
/**
* Extracts aggregation results from a composite aggregation and puts it into a map.
@ -73,27 +80,8 @@ public final class AggregationResultUtils {
// TODO: support other aggregation types
Aggregation aggResult = bucket.getAggregations().get(aggName);
if (aggResult instanceof NumericMetricsAggregation.SingleValue) {
NumericMetricsAggregation.SingleValue aggResultSingleValue = (SingleValue) aggResult;
// If the type is numeric or if the formatted string is the same as simply making the value a string,
// gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
if (isNumericType(fieldType) ||
(aggResultSingleValue.getValueAsString().equals(String.valueOf(aggResultSingleValue.value())))) {
updateDocument(document, aggName, aggResultSingleValue.value());
} else {
updateDocument(document, aggName, aggResultSingleValue.getValueAsString());
}
} else if (aggResult instanceof ScriptedMetric) {
updateDocument(document, aggName, ((ScriptedMetric) aggResult).aggregation());
} else if (aggResult instanceof GeoCentroid) {
updateDocument(document, aggName, ((GeoCentroid) aggResult).centroid().toString());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
throw new AggregationExtractionException("unsupported aggregation [{}] with name [{}]",
aggResult.getType(),
aggResult.getName());
}
AggValueExtractor extractor = getExtractor(aggResult);
updateDocument(document, aggName, extractor.value(aggResult, fieldType));
}
document.put(DataFrameField.DOCUMENT_ID_FIELD, idGen.getID());
@ -102,6 +90,23 @@ public final class AggregationResultUtils {
});
}
/**
 * Resolves the shared value extractor that matches the runtime type of an aggregation result.
 *
 * @param aggregation the aggregation result to find an extractor for
 * @return the extractor registered for single-value, scripted-metric or geo-centroid results
 * @throws AggregationExtractionException if the aggregation is of none of the supported types
 */
static AggValueExtractor getExtractor(Aggregation aggregation) {
    final String extractorKey;
    if (aggregation instanceof SingleValue) {
        extractorKey = SingleValue.class.getName();
    } else if (aggregation instanceof ScriptedMetric) {
        extractorKey = ScriptedMetric.class.getName();
    } else if (aggregation instanceof GeoCentroid) {
        extractorKey = GeoCentroid.class.getName();
    } else {
        // Execution should never reach this point!
        // Creating transforms with unsupported aggregations shall not be possible
        throw new AggregationExtractionException("unsupported aggregation [{}] with name [{}]",
            aggregation.getType(),
            aggregation.getName());
    }
    return TYPE_VALUE_EXTRACTOR_MAP.get(extractorKey);
}
@SuppressWarnings("unchecked")
static void updateDocument(Map<String, Object> document, String fieldName, Object value) {
String[] fieldTokens = fieldName.split("\\.");
@ -147,4 +152,44 @@ public final class AggregationResultUtils {
super(msg, args);
}
}
/**
 * Strategy for turning a single aggregation result into the value that is written
 * into the pivoted document.
 */
private interface AggValueExtractor {
/**
 * Extracts the document value from the given aggregation result.
 *
 * @param aggregation the aggregation result; each implementation casts it to the concrete
 *                    aggregation type it supports
 * @param fieldType   the type of the destination field; implementations may ignore it
 * @return the extracted value, which may be {@code null} (the single-value and geo-centroid
 *         implementations in this file return {@code null} when there is no usable value)
 */
Object value(Aggregation aggregation, String fieldType);
}
/**
 * Extracts the value of a single-value numeric metric aggregation, preferring the raw
 * number and falling back to the formatted string only when formatting carries extra
 * information.
 */
private static class SingleValueAggExtractor implements AggValueExtractor {
    @Override
    public Object value(Aggregation agg, String fieldType) {
        SingleValue singleValue = (SingleValue) agg;

        // A value rejected by Numbers.isValidDouble indicates sparse data: emit null instead.
        if (Numbers.isValidDouble(singleValue.value()) == false) {
            return null;
        }

        // Numeric destination fields always receive the raw number.
        if (isNumericType(fieldType)) {
            return singleValue.value();
        }

        // For other field types the raw number is kept only when the formatted output adds
        // nothing over a plain string conversion; otherwise the formatted string is used so
        // that formatted outputs are not lost.
        String formatted = singleValue.getValueAsString();
        if (formatted.equals(String.valueOf(singleValue.value()))) {
            return singleValue.value();
        }
        return formatted;
    }
}
/**
 * Extracts the result object of a scripted metric aggregation as-is.
 */
private static class ScriptedMetricAggExtractor implements AggValueExtractor {
    @Override
    public Object value(Aggregation agg, String fieldType) {
        // The field type plays no role here; the scripted result is passed through untouched.
        return ((ScriptedMetric) agg).aggregation();
    }
}
/**
 * Extracts the centroid of a geo-centroid aggregation as its string representation.
 */
private static class GeoCentroidAggExtractor implements AggValueExtractor {
    @Override
    public Object value(Aggregation agg, String fieldType) {
        GeoCentroid geoCentroid = (GeoCentroid) agg;
        // A count of 0 means there is no contained centroid, so only read it when count > 0.
        if (geoCentroid.count() > 0) {
            return geoCentroid.centroid().toString();
        }
        return null;
    }
}
}

View File

@ -135,8 +135,8 @@ public class AggregationResultUtilsTests extends ESTestCase {
KEY, asMap(
targetField, "ID3"),
aggTypedName, asMap(
"value", 12.55),
DOC_COUNT, 9)
"value", Double.NaN),
DOC_COUNT, 0)
));
List<Map<String, Object>> expected = asList(
@ -150,14 +150,14 @@ public class AggregationResultUtilsTests extends ESTestCase {
),
asMap(
targetField, "ID3",
aggName, 12.55
aggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
aggName, "double"
);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 20);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 11);
}
public void testExtractCompositeAggregationResultsMultipleGroups() throws IOException {
@ -212,8 +212,8 @@ public class AggregationResultUtilsTests extends ESTestCase {
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", 12.55),
DOC_COUNT, 4)
"value", Double.NaN),
DOC_COUNT, 0)
));
List<Map<String, Object>> expected = asList(
@ -235,7 +235,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, 12.55
aggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
@ -243,7 +243,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 10);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 6);
}
public void testExtractCompositeAggregationResultsMultiAggregations() throws IOException {
@ -287,7 +287,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
aggTypedName, asMap(
"value", 12.55),
aggTypedName2, asMap(
"value", -2.44),
"value", Double.NaN),
DOC_COUNT, 1)
));
@ -305,7 +305,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
asMap(
targetField, "ID3",
aggName, 12.55,
aggName2, -2.44
aggName2, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
@ -383,8 +383,8 @@ public class AggregationResultUtilsTests extends ESTestCase {
aggTypedName, asMap(
"value", 12.55),
aggTypedName2, asMap(
"value", -100.44,
"value_as_string", "-100.44F"),
"value", Double.NaN,
"value_as_string", "NaN"),
DOC_COUNT, 4)
));
@ -411,7 +411,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
targetField, "ID3",
targetField2, "ID2_2",
aggName, 12.55,
aggName2, "-100.44F"
aggName2, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
@ -476,8 +476,8 @@ public class AggregationResultUtilsTests extends ESTestCase {
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 12.0)),
DOC_COUNT, 4)
"value", null),
DOC_COUNT, 0)
));
List<Map<String, Object>> expected = asList(
@ -499,14 +499,14 @@ public class AggregationResultUtilsTests extends ESTestCase {
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, asMap("field", 12.0)
aggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 10);
executeTest(groupBy, aggregationBuilders, Collections.emptyList(), input, fieldTypeMap, expected, 6);
}
public void testExtractCompositeAggregationResultsWithPipelineAggregation() throws IOException {
@ -576,7 +576,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
aggTypedName, asMap(
"value", 12.0),
pipelineAggTypedName, asMap(
"value", 12.0),
"value", Double.NaN),
DOC_COUNT, 4)
));
@ -603,7 +603,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
targetField, "ID3",
targetField2, "ID2_2",
aggName, 12.0,
pipelineAggName, 12.0
pipelineAggName, null
)
);
Map<String, String> fieldTypeMap = asStringMap(