[ML][Data Frame] add support for geo_bounds aggregation (#44441) (#45281)

This adds support for `geo_bounds` aggregation inside the `pivot.aggregations` configuration. 

The two points returned by the `geo_bounds` aggregation are transformed into a `geo_shape` whose type depends on how the two points relate to each other:

* `point` if the two points are identical
* `linestring` if the two points share either a latitude or a longitude (but not both)
* `polygon` if the two points are completely different

The automatically deduced mapping for the resulting field is a `geo_shape`.
This commit is contained in:
Benjamin Trent 2019-08-07 10:37:09 -05:00 committed by GitHub
parent 95d3a8e8ad
commit 3a71b91dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 267 additions and 4 deletions

View File

@ -677,6 +677,60 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
}
/**
 * Creates a pivot transform that groups reviews by reviewer and computes both the average rating
 * and the {@code geo_bounds} of the review locations, then verifies the destination documents.
 */
@SuppressWarnings("unchecked")
public void testPivotWithGeoBoundsAgg() throws Exception {
    final String transformId = "geo_bounds_pivot";
    final String destIndex = "geo_bounds_pivot_reviews";
    setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, destIndex);

    // Transform definition: terms group_by on user_id, avg on stars, geo_bounds on location.
    final String pivotConfig = "{"
        + " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
        + " \"dest\": {\"index\":\"" + destIndex + "\"},"
        + " \"pivot\": {"
        + " \"group_by\": {"
        + " \"reviewer\": {"
        + " \"terms\": {"
        + " \"field\": \"user_id\""
        + " } } },"
        + " \"aggregations\": {"
        + " \"avg_rating\": {"
        + " \"avg\": {"
        + " \"field\": \"stars\""
        + " } },"
        + " \"boundary\": {"
        + " \"geo_bounds\": {\"field\": \"location\"}"
        + " } } }"
        + "}";

    final Request putTransform = createRequestWithAuth(
        "PUT",
        DATAFRAME_ENDPOINT + transformId,
        BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
    putTransform.setJsonEntity(pivotConfig);
    final Map<String, Object> putResponse = entityAsMap(client().performRequest(putTransform));
    assertThat(putResponse.get("acknowledged"), equalTo(Boolean.TRUE));

    startAndWaitForTransform(transformId, destIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
    assertTrue(indexExists(destIndex));

    // one destination document is expected per user_id, and there shall be 27 of those
    final Map<String, Object> stats = getAsMap(destIndex + "/_stats");
    assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", stats));

    // spot-check a single reviewer
    final Map<String, Object> hits = getAsMap(destIndex + "/_search?q=reviewer:user_4");
    assertEquals(1, XContentMapValues.extractValue("hits.total.value", hits));

    final List<?> ratings = (List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", hits);
    final Number avgRating = (Number) ratings.get(0);
    assertEquals(3.878048780, avgRating.doubleValue(), 0.000001);

    // the boundary of this reviewer is expected to collapse into a single point
    final List<?> boundaries = (List<?>) XContentMapValues.extractValue("hits.hits._source.boundary", hits);
    final Map<String, Object> boundary = (Map<String, Object>) boundaries.get(0);
    assertThat(boundary.get("type"), equalTo("point"));

    // point coordinates are stored as [lon, lat]
    final List<Double> lonLat = (List<Double>) boundary.get("coordinates");
    assertEquals((4 + 10), lonLat.get(1), 0.000001);
    assertEquals((4 + 15), lonLat.get(0), 0.000001);
}
public void testPivotWithGeoCentroidAgg() throws Exception {
String transformId = "geo_centroid_pivot";
String dataFrameIndex = "geo_centroid_pivot_reviews";

View File

@ -8,10 +8,16 @@ package org.elasticsearch.xpack.dataframe.transforms.pivot;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.geo.builders.LineStringBuilder;
import org.elasticsearch.common.geo.builders.PointBuilder;
import org.elasticsearch.common.geo.builders.PolygonBuilder;
import org.elasticsearch.common.geo.parsers.ShapeParser;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.GeoBounds;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
@ -20,6 +26,7 @@ import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransfo
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
import org.elasticsearch.xpack.dataframe.transforms.IDGenerator;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@ -38,6 +45,7 @@ public final class AggregationResultUtils {
tempMap.put(SingleValue.class.getName(), new SingleValueAggExtractor());
tempMap.put(ScriptedMetric.class.getName(), new ScriptedMetricAggExtractor());
tempMap.put(GeoCentroid.class.getName(), new GeoCentroidAggExtractor());
tempMap.put(GeoBounds.class.getName(), new GeoBoundsAggExtractor());
TYPE_VALUE_EXTRACTOR_MAP = Collections.unmodifiableMap(tempMap);
}
@ -99,6 +107,8 @@ public final class AggregationResultUtils {
return TYPE_VALUE_EXTRACTOR_MAP.get(ScriptedMetric.class.getName());
} else if (aggregation instanceof GeoCentroid) {
return TYPE_VALUE_EXTRACTOR_MAP.get(GeoCentroid.class.getName());
} else if (aggregation instanceof GeoBounds) {
return TYPE_VALUE_EXTRACTOR_MAP.get(GeoBounds.class.getName());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
@ -155,11 +165,11 @@ public final class AggregationResultUtils {
}
}
private interface AggValueExtractor {
interface AggValueExtractor {
/**
 * Extracts the value to write for a single aggregation result.
 *
 * @param aggregation the aggregation result to read from
 * @param fieldType presumably the mapping type of the destination field (callers in this change pass
 *                  values such as "double", "string", "geo_point" and "geo_shape") - TODO confirm
 * @return the extracted value; the implementations visible here may return {@code null} when the
 *         aggregation holds no usable value
 */
Object value(Aggregation aggregation, String fieldType);
}
private static class SingleValueAggExtractor implements AggValueExtractor {
static class SingleValueAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, String fieldType) {
SingleValue aggregation = (SingleValue)agg;
@ -178,7 +188,7 @@ public final class AggregationResultUtils {
}
}
private static class ScriptedMetricAggExtractor implements AggValueExtractor {
static class ScriptedMetricAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, String fieldType) {
ScriptedMetric aggregation = (ScriptedMetric)agg;
@ -186,7 +196,7 @@ public final class AggregationResultUtils {
}
}
private static class GeoCentroidAggExtractor implements AggValueExtractor {
static class GeoCentroidAggExtractor implements AggValueExtractor {
@Override
public Object value(Aggregation agg, String fieldType) {
GeoCentroid aggregation = (GeoCentroid)agg;
@ -194,4 +204,42 @@ public final class AggregationResultUtils {
return aggregation.count() > 0 ? aggregation.centroid().toString() : null;
}
}
/**
 * Extractor for {@link GeoBounds} results. The two corners of the bounding box are rendered as a
 * map with a shape type entry and a coordinates entry, where the shape type depends on how the
 * corners relate to each other.
 */
static class GeoBoundsAggExtractor implements AggValueExtractor {

    /**
     * Builds the shape map for the given bounds.
     * <ul>
     *   <li>identical corners produce a point: {@code [lon, lat]}</li>
     *   <li>corners sharing a latitude or a longitude produce a line string of the two corners</li>
     *   <li>anything else produces a polygon consisting of one closed, five-position ring</li>
     * </ul>
     *
     * @param agg the aggregation result; cast to {@link GeoBounds}
     * @param fieldType not used by this extractor
     * @return the shape map, or {@code null} if either corner is missing
     */
    @Override
    public Object value(Aggregation agg, String fieldType) {
        final GeoBounds bounds = (GeoBounds) agg;
        final GeoPoint bottomRight = bounds.bottomRight();
        final GeoPoint topLeft = bounds.topLeft();
        if (bottomRight == null || topLeft == null) {
            return null;
        }

        final String shapeName;
        final Object coordinates;
        if (topLeft.equals(bottomRight)) {
            // both corners coincide: a single position is enough
            shapeName = PointBuilder.TYPE.shapeName();
            coordinates = Arrays.asList(topLeft.getLon(), bottomRight.getLat());
        } else if (Double.compare(topLeft.getLat(), bottomRight.getLat()) == 0
            || Double.compare(topLeft.getLon(), bottomRight.getLon()) == 0) {
            // the corners differ in only one dimension, so the box degenerates into a line
            shapeName = LineStringBuilder.TYPE.shapeName();
            coordinates = Arrays.asList(
                lonLat(topLeft.getLon(), topLeft.getLat()),
                lonLat(bottomRight.getLon(), bottomRight.getLat()));
        } else {
            // a proper box: walk the four corners starting at top-left and close the ring
            shapeName = PolygonBuilder.TYPE.shapeName();
            coordinates = Collections.singletonList(Arrays.asList(
                lonLat(topLeft.getLon(), topLeft.getLat()),
                lonLat(bottomRight.getLon(), topLeft.getLat()),
                lonLat(bottomRight.getLon(), bottomRight.getLat()),
                lonLat(topLeft.getLon(), bottomRight.getLat()),
                lonLat(topLeft.getLon(), topLeft.getLat())));
        }

        final Map<String, Object> geoShape = new HashMap<>();
        geoShape.put(ShapeParser.FIELD_TYPE.getPreferredName(), shapeName);
        geoShape.put(ShapeParser.FIELD_COORDINATES.getPreferredName(), coordinates);
        return geoShape;
    }

    /** Packs one position in longitude-first order. */
    private static Double[] lonLat(double lon, double lat) {
        return new Double[] { lon, lat };
    }
}
}

View File

@ -36,6 +36,7 @@ public final class Aggregations {
MIN("min", SOURCE),
SUM("sum", "double"),
GEO_CENTROID("geo_centroid", "geo_point"),
GEO_BOUNDS("geo_bounds", "geo_shape"),
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
WEIGHTED_AVG("weighted_avg", DYNAMIC),
BUCKET_SELECTOR("bucket_selector", DYNAMIC),

View File

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.dataframe.transforms.pivot;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.xcontent.ContextParser;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -31,8 +32,11 @@ import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.CardinalityAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.GeoBounds;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.ParsedAvg;
import org.elasticsearch.search.aggregations.metrics.ParsedCardinality;
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
@ -42,6 +46,7 @@ import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
@ -56,6 +61,7 @@ import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransfo
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@ -67,6 +73,11 @@ import java.util.stream.Collectors;
import static java.util.Arrays.asList;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AggregationResultUtilsTests extends ESTestCase {
@ -781,6 +792,151 @@ public class AggregationResultUtilsTests extends ESTestCase {
equalTo("mixed object types of nested and non-nested fields [foo.bar]"));
}
/**
 * Builds a mocked single-value metric aggregation.
 *
 * @param value the number reported by {@code value()}
 * @param valueAsString the text reported by {@code getValueAsString()}
 */
private NumericMetricsAggregation.SingleValue createSingleMetricAgg(Double value, String valueAsString) {
    final NumericMetricsAggregation.SingleValue mockedAgg = mock(NumericMetricsAggregation.SingleValue.class);
    when(mockedAgg.getValueAsString()).thenReturn(valueAsString);
    when(mockedAgg.value()).thenReturn(value);
    return mockedAgg;
}
/**
 * Checks the single-value extractor: NaN and infinite values yield {@code null}, a "double" field
 * gets the numeric value, and a "string" field gets the formatted value.
 */
public void testSingleValueAggExtractor() {
    // values that are not finite numbers are expected to be dropped
    final Aggregation nanAgg = createSingleMetricAgg(Double.NaN, "NaN");
    assertThat(AggregationResultUtils.getExtractor(nanAgg).value(nanAgg, "double"), is(nullValue()));

    final Aggregation infiniteAgg = createSingleMetricAgg(Double.POSITIVE_INFINITY, "NaN");
    assertThat(AggregationResultUtils.getExtractor(infiniteAgg).value(infiniteAgg, "double"), is(nullValue()));

    // for a double field the numeric value wins, whatever the string representation says
    final Aggregation plainAgg = createSingleMetricAgg(100.0, "100.0");
    assertThat(AggregationResultUtils.getExtractor(plainAgg).value(plainAgg, "double"), equalTo(100.0));

    final Aggregation formattedAsDouble = createSingleMetricAgg(100.0, "one_hundred");
    assertThat(AggregationResultUtils.getExtractor(formattedAsDouble).value(formattedAsDouble, "double"), equalTo(100.0));

    // for a string field the string representation is used instead
    final Aggregation formattedAsString = createSingleMetricAgg(100.0, "one_hundred");
    assertThat(AggregationResultUtils.getExtractor(formattedAsString).value(formattedAsString, "string"),
        equalTo("one_hundred"));
}
/**
 * Builds a mocked scripted metric whose {@code aggregation()} call yields the given object.
 */
private ScriptedMetric createScriptedMetric(Object returnValue) {
    final ScriptedMetric mockedMetric = mock(ScriptedMetric.class);
    when(mockedMetric.aggregation()).thenReturn(returnValue);
    return mockedMetric;
}
/**
 * Checks that the scripted metric extractor hands back the script result for a null, a list and a map.
 */
@SuppressWarnings("unchecked")
public void testScriptedMetricAggExtractor() {
    // no script result: nothing is extracted
    final Aggregation nullResult = createScriptedMetric(null);
    assertThat(AggregationResultUtils.getExtractor(nullResult).value(nullResult, "object"), is(nullValue()));

    // list result
    final Aggregation listResult = createScriptedMetric(Collections.singletonList("values"));
    final Object extractedList = AggregationResultUtils.getExtractor(listResult).value(listResult, "object");
    assertThat((List<String>) extractedList, hasItem("values"));

    // map result
    final Aggregation mapResult = createScriptedMetric(Collections.singletonMap("key", 100));
    final Object extractedMap = AggregationResultUtils.getExtractor(mapResult).value(mapResult, "object");
    assertThat(((Map<String, Object>) extractedMap).get("key"), equalTo(100));
}
/**
 * Builds a mocked geo centroid aggregation.
 *
 * @param point the point reported by {@code centroid()}
 * @param count the number reported by {@code count()}
 */
private GeoCentroid createGeoCentroid(GeoPoint point, long count) {
    final GeoCentroid mockedCentroid = mock(GeoCentroid.class);
    when(mockedCentroid.count()).thenReturn(count);
    when(mockedCentroid.centroid()).thenReturn(point);
    return mockedCentroid;
}
/**
 * Checks the geo centroid extractor: a zero count yields {@code null} whether or not a centroid is
 * present, while a positive count yields the centroid rendered as a string.
 */
public void testGeoCentroidAggExtractor() {
    // zero count and no centroid
    final Aggregation empty = createGeoCentroid(null, 0);
    assertThat(AggregationResultUtils.getExtractor(empty).value(empty, "geo_point"), is(nullValue()));

    // zero count wins even if a centroid is reported
    final Aggregation zeroCount = createGeoCentroid(new GeoPoint(100.0, 101.0), 0);
    assertThat(AggregationResultUtils.getExtractor(zeroCount).value(zeroCount, "geo_point"), is(nullValue()));

    // any positive count produces the string form of the centroid
    final Aggregation populated = createGeoCentroid(new GeoPoint(100.0, 101.0), randomIntBetween(1, 100));
    assertThat(AggregationResultUtils.getExtractor(populated).value(populated, "geo_point"), equalTo("100.0, 101.0"));
}
/**
 * Builds a mocked geo bounds aggregation.
 *
 * @param tl the point reported by {@code topLeft()}
 * @param br the point reported by {@code bottomRight()}
 */
private GeoBounds createGeoBounds(GeoPoint tl, GeoPoint br) {
    final GeoBounds mockedBounds = mock(GeoBounds.class);
    when(mockedBounds.topLeft()).thenReturn(tl);
    when(mockedBounds.bottomRight()).thenReturn(br);
    return mockedBounds;
}
/**
 * Checks the geo bounds extractor for all of its outcomes: {@code null} when a corner is missing,
 * a point for identical corners, a line string when the corners share exactly one of latitude or
 * longitude, and a closed five-position polygon ring otherwise.
 */
@SuppressWarnings("unchecked")
public void testGeoBoundsAggExtractor() {
    final int numberOfRuns = 25;

    // a missing corner on either side yields no value
    Aggregation agg = createGeoBounds(null, new GeoPoint(100.0, 101.0));
    assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape"), is(nullValue()));

    agg = createGeoBounds(new GeoPoint(100.0, 101.0), null);
    assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape"), is(nullValue()));

    // identical corners -> point with [lon, lat] coordinates
    String type = "point";
    for (int i = 0; i < numberOfRuns; i++) {
        Map<String, Object> expectedObject = new HashMap<>();
        expectedObject.put("type", type);
        double lat = randomDoubleBetween(-90.0, 90.0, false);
        double lon = randomDoubleBetween(-180.0, 180.0, false);
        expectedObject.put("coordinates", Arrays.asList(lon, lat));
        agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat, lon));
        assertThat(AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape"), equalTo(expectedObject));
    }

    // corners that differ in exactly one dimension -> linestring of the two corners
    type = "linestring";
    for (int i = 0; i < numberOfRuns; i++) {
        double lat = randomDoubleBetween(-90.0, 90.0, false);
        double lon = randomDoubleBetween(-180.0, 180.0, false);
        double lat2 = lat;
        double lon2 = lon;
        // Re-draw on the (unlikely) collision with the original value: identical corners would make
        // the extractor return a point and fail this case. The polygon case below guards the same way.
        if (randomBoolean()) {
            do {
                lat2 = randomDoubleBetween(-90.0, 90.0, false);
            } while (Double.compare(lat, lat2) == 0);
        } else {
            do {
                lon2 = randomDoubleBetween(-180.0, 180.0, false);
            } while (Double.compare(lon, lon2) == 0);
        }
        agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat2, lon2));
        Object val = AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape");
        Map<String, Object> geoJson = (Map<String, Object>) val;
        assertThat(geoJson.get("type"), equalTo(type));
        List<Double[]> coordinates = (List<Double[]>) geoJson.get("coordinates");
        for (Double[] coor : coordinates) {
            assertThat(coor.length, equalTo(2));
        }
        assertThat(coordinates.get(0)[0], equalTo(lon));
        assertThat(coordinates.get(0)[1], equalTo(lat));
        assertThat(coordinates.get(1)[0], equalTo(lon2));
        assertThat(coordinates.get(1)[1], equalTo(lat2));
    }

    // corners that differ in both dimensions -> polygon with a single closed ring
    type = "polygon";
    for (int i = 0; i < numberOfRuns; i++) {
        double lat = randomDoubleBetween(-90.0, 90.0, false);
        double lon = randomDoubleBetween(-180.0, 180.0, false);
        double lat2 = randomDoubleBetween(-90.0, 90.0, false);
        double lon2 = randomDoubleBetween(-180.0, 180.0, false);
        // both dimensions must differ, otherwise the result would be a point or a linestring
        while (Double.compare(lat, lat2) == 0 || Double.compare(lon, lon2) == 0) {
            lat2 = randomDoubleBetween(-90.0, 90.0, false);
            lon2 = randomDoubleBetween(-180.0, 180.0, false);
        }
        agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat2, lon2));
        Object val = AggregationResultUtils.getExtractor(agg).value(agg, "geo_shape");
        Map<String, Object> geoJson = (Map<String, Object>) val;
        assertThat(geoJson.get("type"), equalTo(type));
        List<List<Double[]>> coordinates = (List<List<Double[]>>) geoJson.get("coordinates");
        assertThat(coordinates.size(), equalTo(1));
        assertThat(coordinates.get(0).size(), equalTo(5));
        // ring order: top-left, top-right, bottom-right, bottom-left, back to top-left
        List<List<Double>> expected = Arrays.asList(
            Arrays.asList(lon, lat),
            Arrays.asList(lon2, lat),
            Arrays.asList(lon2, lat2),
            Arrays.asList(lon, lat2),
            Arrays.asList(lon, lat));
        for (int j = 0; j < 5; j++) {
            Double[] coordinate = coordinates.get(0).get(j);
            assertThat(coordinate.length, equalTo(2));
            assertThat(coordinate[0], equalTo(expected.get(j).get(0)));
            assertThat(coordinate[1], equalTo(expected.get(j).get(1)));
        }
    }
}
private void executeTest(GroupConfig groups,
Collection<AggregationBuilder> aggregationBuilders,

View File

@ -42,6 +42,10 @@ public class AggregationsTests extends ESTestCase {
assertEquals("geo_point", Aggregations.resolveTargetMapping("geo_centroid", "geo_point"));
assertEquals("geo_point", Aggregations.resolveTargetMapping("geo_centroid", null));
// geo_bounds
assertEquals("geo_shape", Aggregations.resolveTargetMapping("geo_bounds", "geo_shape"));
assertEquals("geo_shape", Aggregations.resolveTargetMapping("geo_bounds", null));
// scripted_metric
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));