[ML] adds geo_centroid aggregation support to data frames (#42088) (#42094)

This commit is contained in:
Benjamin Trent 2019-05-17 16:51:05 -04:00 committed by GitHub
parent 076ca75ea5
commit f2447364fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 2 deletions

View File

@ -418,6 +418,56 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
}
/**
 * Creates and runs a pivot transform over the reviews index that groups by {@code user_id}
 * and computes two aggregations per reviewer: the {@code avg} of {@code stars} and the
 * {@code geo_centroid} of {@code location}. Afterwards the destination index is checked for
 * the expected document count and for the values written for {@code user_4}.
 */
public void testPivotWithGeoCentroidAgg() throws Exception {
    final String transformId = "geoCentroidPivot";
    final String dataFrameIndex = "geo_centroid_pivot_reviews";
    setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

    // Transform definition: source/dest indices followed by the pivot (group_by + aggregations).
    final String config = "{"
        + " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
        + " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},"
        + " \"pivot\": {"
        + " \"group_by\": {"
        + " \"reviewer\": {"
        + " \"terms\": {"
        + " \"field\": \"user_id\""
        + " } } },"
        + " \"aggregations\": {"
        + " \"avg_rating\": {"
        + " \"avg\": {"
        + " \"field\": \"stars\""
        + " } },"
        + " \"location\": {"
        + " \"geo_centroid\": {\"field\": \"location\"}"
        + " } } }"
        + "}";

    // PUT the transform and make sure the cluster acknowledged it.
    final Request putTransform = createRequestWithAuth(
        "PUT",
        DATAFRAME_ENDPOINT + transformId,
        BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
    putTransform.setJsonEntity(config);
    Map<String, Object> putResponse = entityAsMap(client().performRequest(putTransform));
    assertThat(putResponse.get("acknowledged"), equalTo(Boolean.TRUE));

    // Run the transform to completion; the destination index must exist afterwards.
    startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
    assertTrue(indexExists(dataFrameIndex));

    // One document per user_id is expected, 27 in total.
    Map<String, Object> stats = getAsMap(dataFrameIndex + "/_stats");
    assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", stats));

    // Spot-check a single reviewer.
    Map<String, Object> hitsForUser4 = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
    assertEquals(1, XContentMapValues.extractValue("hits.total.value", hitsForUser4));

    List<?> avgRatings = (List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", hitsForUser4);
    Number avgRating = (Number) avgRatings.get(0);
    assertEquals(3.878048780, avgRating.doubleValue(), 0.000001);

    // The centroid is read back as a single comma-separated string whose first part is
    // checked against 4 + 10 and second part against 4 + 15 (presumably derived from the
    // user number 4 in the generated test data - confirm against the data setup).
    List<?> centroids = (List<?>) XContentMapValues.extractValue("hits.hits._source.location", hitsForUser4);
    String centroid = (String) centroids.get(0);
    String[] centroidParts = centroid.split(",");
    assertEquals((4 + 10), Double.valueOf(centroidParts[0]), 0.000001);
    assertEquals((4 + 15), Double.valueOf(centroidParts[1]), 0.000001);
}
private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

View File

@ -76,6 +76,9 @@ public abstract class DataFrameRestTestCase extends ESRestTestCase {
.startObject("stars")
.field("type", "integer")
.endObject()
.startObject("location")
.field("type", "geo_point")
.endObject()
.endObject()
.endObject();
}
@ -103,6 +106,7 @@ public abstract class DataFrameRestTestCase extends ESRestTestCase {
min = 10 + (i % 49);
}
int sec = 10 + (i % 49);
String location = (user + 10) + "," + (user + 15);
String date_string = "2017-01-" + day + "T" + hour + ":" + min + ":" + sec + "Z";
bulk.append("{\"user_id\":\"")
@ -113,7 +117,9 @@ public abstract class DataFrameRestTestCase extends ESRestTestCase {
.append(business)
.append("\",\"stars\":")
.append(stars)
.append(",\"timestamp\":\"")
.append(",\"location\":\"")
.append(location)
.append("\",\"timestamp\":\"")
.append(date_string)
.append("\"}\n");

View File

@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
@ -84,6 +85,8 @@ public final class AggregationResultUtils {
}
} else if (aggResult instanceof ScriptedMetric) {
updateDocument(document, aggName, ((ScriptedMetric) aggResult).aggregation());
} else if (aggResult instanceof GeoCentroid) {
updateDocument(document, aggName, ((GeoCentroid) aggResult).centroid().toString());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible

View File

@ -35,6 +35,7 @@ public final class Aggregations {
MAX("max", SOURCE),
MIN("min", SOURCE),
SUM("sum", SOURCE),
GEO_CENTROID("geo_centroid", "geo_point"),
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
BUCKET_SCRIPT("bucket_script", DYNAMIC);

View File

@ -38,11 +38,15 @@ public class AggregationsTests extends ESTestCase {
assertEquals("double", Aggregations.resolveTargetMapping("sum", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float"));
// geo_centroid
assertEquals("geo_point", Aggregations.resolveTargetMapping("geo_centroid", "geo_point"));
assertEquals("geo_point", Aggregations.resolveTargetMapping("geo_centroid", null));
// scripted_metric
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));
// scripted_metric
// bucket_script
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", "int"));
}