From 03fbca3f500d8541b4b32c1456997a8493ebe4f5 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Thu, 21 Apr 2022 07:06:33 -0700 Subject: [PATCH] Add new multi_term aggregation (#2687) Adds a new multi_term aggregation. The current implementation focuses on adding new type aggregates. Performance (latency) is suboptimal in this iteration, mainly because of brute force encoding/decoding a list of values into bucket keys. A performance improvement change will be made as a follow on. Signed-off-by: Peng Huo --- .../client/RestHighLevelClient.java | 3 + .../search.aggregation/370_multi_terms.yml | 620 ++++++++++++ .../aggregations/bucket/MultiTermsIT.java | 167 ++++ .../bucket/terms/BaseStringTermsTestCase.java | 256 +++++ .../bucket/terms/StringTermsIT.java | 239 +---- .../org/opensearch/search/SearchModule.java | 9 + .../aggregations/AggregationBuilders.java | 8 + .../bucket/terms/InternalMultiTerms.java | 440 +++++++++ .../bucket/terms/InternalTerms.java | 59 +- .../terms/MultiTermsAggregationBuilder.java | 443 +++++++++ .../terms/MultiTermsAggregationFactory.java | 163 ++++ .../bucket/terms/MultiTermsAggregator.java | 438 +++++++++ .../bucket/terms/ParsedMultiTerms.java | 77 ++ .../bucket/terms/ParsedTerms.java | 7 +- .../BaseMultiValuesSourceFieldConfig.java | 216 +++++ .../support/MultiTermsValuesSourceConfig.java | 203 ++++ .../support/MultiValuesSourceFieldConfig.java | 160 +-- .../aggregations/AggregationsTests.java | 2 + .../bucket/terms/InternalMultiTermsTests.java | 116 +++ .../MultiTermsAggregationBuilderTests.java | 182 ++++ .../terms/MultiTermsAggregatorTests.java | 909 ++++++++++++++++++ .../MultiTermsValuesSourceConfigTests.java | 65 ++ .../test/InternalAggregationTestCase.java | 3 + 23 files changed, 4378 insertions(+), 407 deletions(-) create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/370_multi_terms.yml create mode 100644 
server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/MultiTermsIT.java create mode 100644 server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/BaseStringTermsTestCase.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilder.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationFactory.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedMultiTerms.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/support/BaseMultiValuesSourceFieldConfig.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfig.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTermsTests.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilderTests.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfigTests.java diff --git a/client/rest-high-level/src/main/java/org/opensearch/client/RestHighLevelClient.java b/client/rest-high-level/src/main/java/org/opensearch/client/RestHighLevelClient.java index 3eebb361fd9..e69ca149d69 100644 --- a/client/rest-high-level/src/main/java/org/opensearch/client/RestHighLevelClient.java +++ b/client/rest-high-level/src/main/java/org/opensearch/client/RestHighLevelClient.java @@ -139,7 +139,9 @@ import 
org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; import org.opensearch.search.aggregations.bucket.sampler.InternalSampler; import org.opensearch.search.aggregations.bucket.sampler.ParsedSampler; import org.opensearch.search.aggregations.bucket.terms.LongRareTerms; +import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.ParsedLongRareTerms; +import org.opensearch.search.aggregations.bucket.terms.ParsedMultiTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedSignificantLongTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedSignificantStringTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedStringRareTerms; @@ -2140,6 +2142,7 @@ public class RestHighLevelClient implements Closeable { map.put(IpRangeAggregationBuilder.NAME, (p, c) -> ParsedBinaryRange.fromXContent(p, (String) c)); map.put(TopHitsAggregationBuilder.NAME, (p, c) -> ParsedTopHits.fromXContent(p, (String) c)); map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c)); + map.put(MultiTermsAggregationBuilder.NAME, (p, c) -> ParsedMultiTerms.fromXContent(p, (String) c)); List entries = map.entrySet() .stream() .map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue())) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/370_multi_terms.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/370_multi_terms.yml new file mode 100644 index 00000000000..a0e4762ea9b --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/370_multi_terms.yml @@ -0,0 +1,620 @@ +setup: + - do: + indices.create: + index: test_1 + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + str: + type: keyword + ip: + type: ip + boolean: + type: boolean + integer: + 
type: long + double: + type: double + number: + type: long + date: + type: date + + - do: + indices.create: + index: test_2 + body: + settings: + number_of_shards: 2 + number_of_replicas: 0 + mappings: + properties: + str: + type: keyword + integer: + type: long + boolean: + type: boolean + + - do: + cluster.health: + wait_for_status: green + +--- +"Basic test": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "integer": 1}' + - '{"index": {}}' + - '{"str": "a", "integer": 2}' + - '{"index": {}}' + - '{"str": "b", "integer": 1}' + - '{"index": {}}' + - '{"str": "b", "integer": 2}' + - '{"index": {}}' + - '{"str": "a", "integer": 1}' + - '{"index": {}}' + - '{"str": "b", "integer": 1}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: integer + + - length: { aggregations.m_terms.buckets: 4 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 1] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|1" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["b", 1] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "b|1" } + - match: { aggregations.m_terms.buckets.1.doc_count: 2 } + - match: { aggregations.m_terms.buckets.2.key: ["a", 2] } + - match: { aggregations.m_terms.buckets.2.key_as_string: "a|2" } + - match: { aggregations.m_terms.buckets.2.doc_count: 1 } + - match: { aggregations.m_terms.buckets.3.key: ["b", 2] } + - match: { aggregations.m_terms.buckets.3.key_as_string: "b|2" } + - match: { aggregations.m_terms.buckets.3.doc_count: 1 } + +--- +"IP test": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "ip": "::1"}' + - '{"index": {}}' + - 
'{"str": "a", "ip": "127.0.0.1"}' + - '{"index": {}}' + - '{"str": "b", "ip": "::1"}' + - '{"index": {}}' + - '{"str": "b", "ip": "127.0.0.1"}' + - '{"index": {}}' + - '{"str": "a", "ip": "127.0.0.1"}' + - '{"index": {}}' + - '{"str": "b", "ip": "::1"}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: ip + + - length: { aggregations.m_terms.buckets: 4 } + - match: { aggregations.m_terms.buckets.0.key: ["a", "127.0.0.1"] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|127.0.0.1" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["b", "::1"] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "b|::1" } + - match: { aggregations.m_terms.buckets.1.doc_count: 2 } + - match: { aggregations.m_terms.buckets.2.key: ["a", "::1"] } + - match: { aggregations.m_terms.buckets.2.key_as_string: "a|::1" } + - match: { aggregations.m_terms.buckets.2.doc_count: 1 } + - match: { aggregations.m_terms.buckets.3.key: ["b", "127.0.0.1"] } + - match: { aggregations.m_terms.buckets.3.key_as_string: "b|127.0.0.1" } + - match: { aggregations.m_terms.buckets.3.doc_count: 1 } + +--- +"Boolean test": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "boolean": true}' + - '{"index": {}}' + - '{"str": "a", "boolean": false}' + - '{"index": {}}' + - '{"str": "b", "boolean": false}' + - '{"index": {}}' + - '{"str": "b", "boolean": true}' + - '{"index": {}}' + - '{"str": "a", "boolean": true}' + - '{"index": {}}' + - '{"str": "b", "boolean": false}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: boolean + + - length: { aggregations.m_terms.buckets: 4 } + - match: { aggregations.m_terms.buckets.0.key: ["a", true] } + - match: { 
aggregations.m_terms.buckets.0.key_as_string: "a|true" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["b", false] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "b|false" } + - match: { aggregations.m_terms.buckets.1.doc_count: 2 } + - match: { aggregations.m_terms.buckets.2.key: ["a", false] } + - match: { aggregations.m_terms.buckets.2.key_as_string: "a|false" } + - match: { aggregations.m_terms.buckets.2.doc_count: 1 } + - match: { aggregations.m_terms.buckets.3.key: ["b", true] } + - match: { aggregations.m_terms.buckets.3.key_as_string: "b|true" } + - match: { aggregations.m_terms.buckets.3.doc_count: 1 } + +--- +"Double test": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "double": 1234.5}' + - '{"index": {}}' + - '{"str": "a", "double": 5678.5}' + - '{"index": {}}' + - '{"str": "b", "double": 1234.5}' + - '{"index": {}}' + - '{"str": "a", "double": 1234.5}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: double + + - length: { aggregations.m_terms.buckets: 3 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 1234.5] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|1234.5" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["a", 5678.5] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "a|5678.5" } + - match: { aggregations.m_terms.buckets.1.doc_count: 1 } + - match: { aggregations.m_terms.buckets.2.key: ["b", 1234.5] } + - match: { aggregations.m_terms.buckets.2.key_as_string: "b|1234.5" } + - match: { aggregations.m_terms.buckets.2.doc_count: 1 } + +--- +"Date test": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: 
test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "date": "2022-03-23"}' + - '{"index": {}}' + - '{"str": "a", "date": "2022-03-25"}' + - '{"index": {}}' + - '{"str": "b", "date": "2022-03-23"}' + - '{"index": {}}' + - '{"str": "a", "date": "2022-03-23"}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: date + + - length: { aggregations.m_terms.buckets: 3 } + - match: { aggregations.m_terms.buckets.0.key: ["a", "2022-03-23T00:00:00.000Z"] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|2022-03-23T00:00:00.000Z" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["a", "2022-03-25T00:00:00.000Z"] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "a|2022-03-25T00:00:00.000Z" } + - match: { aggregations.m_terms.buckets.1.doc_count: 1 } + - match: { aggregations.m_terms.buckets.2.key: ["b", "2022-03-23T00:00:00.000Z"] } + - match: { aggregations.m_terms.buckets.2.key_as_string: "b|2022-03-23T00:00:00.000Z" } + - match: { aggregations.m_terms.buckets.2.doc_count: 1 } + +--- +"Unmapped keywords": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "integer": 1}' + - '{"index": {}}' + - '{"str": "a", "integer": 2}' + - '{"index": {}}' + - '{"str": "b", "integer": 1}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: unmapped_string + value_type: string + missing: abc + + - length: { aggregations.m_terms.buckets: 2 } + - match: { aggregations.m_terms.buckets.0.key: ["a", "abc"] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|abc" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["b", "abc"] } + - match: { 
aggregations.m_terms.buckets.1.key_as_string: "b|abc" } + - match: { aggregations.m_terms.buckets.1.doc_count: 1 } + +--- +"Null value": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "integer": null}' + - '{"index": {}}' + - '{"str": "a", "integer": 2}' + - '{"index": {}}' + - '{"str": null, "integer": 1}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: integer + + - length: { aggregations.m_terms.buckets: 1 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 2] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|2" } + - match: { aggregations.m_terms.buckets.0.doc_count: 1 } + +--- +"multiple multi_terms bucket": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "integer": 1, "double": 1234.5, "boolean": true}' + - '{"index": {}}' + - '{"str": "a", "integer": 1, "double": 5678.9, "boolean": false}' + - '{"index": {}}' + - '{"str": "a", "integer": 1, "double": 1234.5, "boolean": true}' + - '{"index": {}}' + - '{"str": "b", "integer": 1, "double": 1234.5, "boolean": true}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: integer + aggs: + n_terms: + multi_terms: + terms: + - field: double + - field: boolean + + - length: { aggregations.m_terms.buckets: 2 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 1] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|1" } + - match: { aggregations.m_terms.buckets.0.doc_count: 3 } + - match: { aggregations.m_terms.buckets.0.n_terms.buckets.0.key: [1234.5, true] } + - match: { aggregations.m_terms.buckets.0.n_terms.buckets.0.key_as_string: "1234.5|true" } + - match: { 
aggregations.m_terms.buckets.0.n_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.0.n_terms.buckets.1.key: [5678.9, false] } + - match: { aggregations.m_terms.buckets.0.n_terms.buckets.1.key_as_string: "5678.9|false" } + - match: { aggregations.m_terms.buckets.0.n_terms.buckets.1.doc_count: 1 } + - match: { aggregations.m_terms.buckets.1.key: ["b", 1] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "b|1" } + - match: { aggregations.m_terms.buckets.1.doc_count: 1 } + +--- +"ordered by metrics": + - skip: + version: "- 3.0.0" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "double": 1234.5, "integer": 1}' + - '{"index": {}}' + - '{"str": "b", "double": 5678.9, "integer": 2}' + - '{"index": {}}' + - '{"str": "b", "double": 5678.9, "integer": 2}' + - '{"index": {}}' + - '{"str": "a", "double": 1234.5, "integer": 1}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: double + order: + the_int_sum: desc + aggs: + the_int_sum: + sum: + field: integer + + - length: { aggregations.m_terms.buckets: 2 } + - match: { aggregations.m_terms.buckets.0.key: ["b", 5678.9] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "b|5678.9" } + - match: { aggregations.m_terms.buckets.0.the_int_sum.value: 4.0 } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["a", 1234.5] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "a|1234.5" } + - match: { aggregations.m_terms.buckets.1.the_int_sum.value: 2.0 } + - match: { aggregations.m_terms.buckets.1.doc_count: 2 } + +--- +"top 1 ordered by metrics ": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "double": 1234.5, 
"integer": 1}' + - '{"index": {}}' + - '{"str": "b", "double": 5678.9, "integer": 2}' + - '{"index": {}}' + - '{"str": "b", "double": 5678.9, "integer": 2}' + - '{"index": {}}' + - '{"str": "a", "double": 1234.5, "integer": 1}' + + - do: + search: + index: test_1 + size: 0 + body: + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: double + order: + the_int_sum: desc + size: 1 + aggs: + the_int_sum: + sum: + field: integer + + - length: { aggregations.m_terms.buckets: 1 } + - match: { aggregations.m_terms.buckets.0.key: ["b", 5678.9] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "b|5678.9" } + - match: { aggregations.m_terms.buckets.0.the_int_sum.value: 4.0 } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + +--- +"min_doc_count": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_1 + refresh: true + body: + - '{"index": {}}' + - '{"str": "a", "integer": 1}' + - '{"index": {}}' + - '{"str": "a", "integer": 1}' + - '{"index": {}}' + - '{"str": "b", "integer": 1}' + - '{"index": {}}' + - '{"str": "c", "integer": 1}' + + - do: + search: + index: test_1 + body: + size: 0 + query: + simple_query_string: + fields: [str] + query: a b + minimum_should_match: 1 + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: integer + min_doc_count: 2 + + - length: { aggregations.m_terms.buckets: 1 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 1] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|1" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + + - do: + search: + index: test_1 + body: + size: 0 + query: + simple_query_string: + fields: [str] + query: a b + minimum_should_match: 1 + aggs: + m_terms: + multi_terms: + terms: + - field: str + - field: integer + min_doc_count: 0 + + - length: { aggregations.m_terms.buckets: 3 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 1] } + - match: { 
aggregations.m_terms.buckets.0.key_as_string: "a|1" } + - match: { aggregations.m_terms.buckets.0.doc_count: 2 } + - match: { aggregations.m_terms.buckets.1.key: ["b", 1] } + - match: { aggregations.m_terms.buckets.1.key_as_string: "b|1" } + - match: { aggregations.m_terms.buckets.1.doc_count: 1 } + - match: { aggregations.m_terms.buckets.2.key: ["c", 1] } + - match: { aggregations.m_terms.buckets.2.key_as_string: "c|1" } + - match: { aggregations.m_terms.buckets.2.doc_count: 0 } + +--- +"sum_other_doc_count": + - skip: + version: "- 2.9.99" + reason: multi_terms aggregation is introduced in 3.0.0 + + - do: + bulk: + index: test_2 + refresh: true + body: + - '{"index": {"routing": "s1"}}' + - '{"str": "a", "integer": 1}' + - '{"index": {"routing": "s1"}}' + - '{"str": "a", "integer": 1}' + - '{"index": {"routing": "s1"}}' + - '{"str": "a", "integer": 1}' + - '{"index": {"routing": "s1"}}' + - '{"str": "a", "integer": 1}' + - '{"index": {"routing": "s2"}}' + - '{"str": "b", "integer": 1}' + - '{"index": {"routing": "s2"}}' + - '{"str": "b", "integer": 1}' + - '{"index": {"routing": "s2"}}' + - '{"str": "b", "integer": 1}' + - '{"index": {"routing": "s2"}}' + - '{"str": "a", "integer": 1}' + + - do: + search: + index: test_2 + size: 0 + body: + aggs: + m_terms: + multi_terms: + size: 1 + shard_size: 1 + terms: + - field: str + - field: integer + + - length: { aggregations.m_terms.buckets: 1 } + - match: { aggregations.m_terms.sum_other_doc_count: 4 } + - match: { aggregations.m_terms.buckets.0.key: ["a", 1] } + - match: { aggregations.m_terms.buckets.0.key_as_string: "a|1" } + - match: { aggregations.m_terms.buckets.0.doc_count: 4 } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/MultiTermsIT.java b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/MultiTermsIT.java new file mode 100644 index 00000000000..7d7f80c8ac7 --- /dev/null +++ 
b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/MultiTermsIT.java @@ -0,0 +1,167 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.bucket; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; +import org.opensearch.search.aggregations.bucket.terms.BaseStringTermsTestCase; +import org.opensearch.search.aggregations.bucket.terms.StringTermsIT; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValueType; +import org.opensearch.test.OpenSearchIntegTestCase; + +import java.util.Collections; + +import static java.util.Arrays.asList; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsNull.notNullValue; +import static org.opensearch.search.aggregations.AggregationBuilders.multiTerms; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; + +/** + * Extend {@link BaseStringTermsTestCase}. 
+ */ +@OpenSearchIntegTestCase.SuiteScopeTestCase +public class MultiTermsIT extends BaseStringTermsTestCase { + + // the main purpose of this test is to make sure we're not allocating 2GB of memory per shard + public void testSizeIsZero() { + final int minDocCount = randomInt(1); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> client().prepareSearch("high_card_idx") + .addAggregation( + multiTerms("mterms").terms( + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(SINGLE_VALUED_FIELD_NAME).build(), + new MultiTermsValuesSourceConfig.Builder().setFieldName(MULTI_VALUED_FIELD_NAME).build() + ) + ).minDocCount(minDocCount).size(0) + ) + .get() + ); + assertThat(exception.getMessage(), containsString("[size] must be greater than 0. Found [0] in [mterms]")); + } + + public void testSingleValuedFieldWithValueScript() throws Exception { + SearchResponse response = client().prepareSearch("idx") + .addAggregation( + multiTerms("mterms").terms( + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName("i").build(), + new MultiTermsValuesSourceConfig.Builder().setFieldName(SINGLE_VALUED_FIELD_NAME) + .setScript( + new Script( + ScriptType.INLINE, + StringTermsIT.CustomScriptPlugin.NAME, + "'foo_' + _value", + Collections.emptyMap() + ) + ) + .build() + ) + ) + ) + .get(); + + assertSearchResponse(response); + + Terms terms = response.getAggregations().get("mterms"); + assertThat(terms, notNullValue()); + assertThat(terms.getName(), equalTo("mterms")); + assertThat(terms.getBuckets().size(), equalTo(5)); + + for (int i = 0; i < 5; i++) { + Terms.Bucket bucket = terms.getBucketByKey(i + "|foo_val" + i); + assertThat(bucket, notNullValue()); + assertThat(key(bucket), equalTo(i + "|foo_val" + i)); + assertThat(bucket.getDocCount(), equalTo(1L)); + } + } + + public void testSingleValuedFieldWithScript() throws Exception { + SearchResponse response = client().prepareSearch("idx") + .addAggregation( + 
multiTerms("mterms").terms( + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName("i").build(), + new MultiTermsValuesSourceConfig.Builder().setScript( + new Script( + ScriptType.INLINE, + StringTermsIT.CustomScriptPlugin.NAME, + "doc['" + SINGLE_VALUED_FIELD_NAME + "'].value", + Collections.emptyMap() + ) + ).setUserValueTypeHint(ValueType.STRING).build() + ) + ) + ) + .get(); + + assertSearchResponse(response); + + Terms terms = response.getAggregations().get("mterms"); + assertThat(terms, notNullValue()); + assertThat(terms.getName(), equalTo("mterms")); + assertThat(terms.getBuckets().size(), equalTo(5)); + + for (int i = 0; i < 5; i++) { + Terms.Bucket bucket = terms.getBucketByKey(i + "|val" + i); + assertThat(bucket, notNullValue()); + assertThat(key(bucket), equalTo(i + "|val" + i)); + assertThat(bucket.getDocCount(), equalTo(1L)); + } + } + + public void testMultiValuedFieldWithValueScript() throws Exception { + SearchResponse response = client().prepareSearch("idx") + .addAggregation( + multiTerms("mterms").terms( + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName("tag").build(), + new MultiTermsValuesSourceConfig.Builder().setFieldName(MULTI_VALUED_FIELD_NAME) + .setScript( + new Script( + ScriptType.INLINE, + StringTermsIT.CustomScriptPlugin.NAME, + "_value.substring(0,3)", + Collections.emptyMap() + ) + ) + .build() + ) + ) + ) + .get(); + + assertSearchResponse(response); + + Terms terms = response.getAggregations().get("mterms"); + assertThat(terms, notNullValue()); + assertThat(terms.getName(), equalTo("mterms")); + assertThat(terms.getBuckets().size(), equalTo(2)); + + Terms.Bucket bucket = terms.getBucketByKey("more|val"); + assertThat(bucket, notNullValue()); + assertThat(key(bucket), equalTo("more|val")); + assertThat(bucket.getDocCount(), equalTo(3L)); + + bucket = terms.getBucketByKey("less|val"); + assertThat(bucket, notNullValue()); + assertThat(key(bucket), equalTo("less|val")); + 
assertThat(bucket.getDocCount(), equalTo(2L)); + } + + private MultiTermsValuesSourceConfig field(String name) { + return new MultiTermsValuesSourceConfig.Builder().setFieldName(name).build(); + } +} diff --git a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/BaseStringTermsTestCase.java b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/BaseStringTermsTestCase.java new file mode 100644 index 00000000000..7775618ba5b --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/BaseStringTermsTestCase.java @@ -0,0 +1,256 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.index.IndexRequestBuilder; +import org.opensearch.common.Strings; +import org.opensearch.index.fielddata.ScriptDocValues; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.aggregations.AggregationTestScriptsPlugin; +import org.opensearch.search.aggregations.bucket.AbstractTermsTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; + +@OpenSearchIntegTestCase.SuiteScopeTestCase +public class BaseStringTermsTestCase extends AbstractTermsTestCase { + + protected static final String SINGLE_VALUED_FIELD_NAME = "s_value"; + protected static final String MULTI_VALUED_FIELD_NAME = 
"s_values"; + protected static Map> expectedMultiSortBuckets; + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(CustomScriptPlugin.class); + } + + @Before + public void randomizeOptimizations() { + TermsAggregatorFactory.COLLECT_SEGMENT_ORDS = randomBoolean(); + TermsAggregatorFactory.REMAP_GLOBAL_ORDS = randomBoolean(); + } + + @After + public void resetOptimizations() { + TermsAggregatorFactory.COLLECT_SEGMENT_ORDS = null; + TermsAggregatorFactory.REMAP_GLOBAL_ORDS = null; + } + + public static class CustomScriptPlugin extends AggregationTestScriptsPlugin { + + @Override + protected Map, Object>> pluginScripts() { + Map, Object>> scripts = super.pluginScripts(); + + scripts.put("'foo_' + _value", vars -> "foo_" + (String) vars.get("_value")); + scripts.put("_value.substring(0,3)", vars -> ((String) vars.get("_value")).substring(0, 3)); + + scripts.put("doc['" + MULTI_VALUED_FIELD_NAME + "']", vars -> { + Map doc = (Map) vars.get("doc"); + return doc.get(MULTI_VALUED_FIELD_NAME); + }); + + scripts.put("doc['" + SINGLE_VALUED_FIELD_NAME + "'].value", vars -> { + Map doc = (Map) vars.get("doc"); + ScriptDocValues.Strings value = (ScriptDocValues.Strings) doc.get(SINGLE_VALUED_FIELD_NAME); + return value.getValue(); + }); + + scripts.put("42", vars -> 42); + + return scripts; + } + + @Override + protected Map, Object>> nonDeterministicPluginScripts() { + Map, Object>> scripts = new HashMap<>(); + + scripts.put("Math.random()", vars -> randomDouble()); + + return scripts; + } + } + + @Override + public void setupSuiteScopeCluster() throws Exception { + assertAcked( + client().admin() + .indices() + .prepareCreate("idx") + .setMapping(SINGLE_VALUED_FIELD_NAME, "type=keyword", MULTI_VALUED_FIELD_NAME, "type=keyword", "tag", "type=keyword") + .get() + ); + List builders = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + builders.add( + client().prepareIndex("idx") + .setSource( + jsonBuilder().startObject() + 
.field(SINGLE_VALUED_FIELD_NAME, "val" + i) + .field("i", i) + .field("constant", 1) + .field("tag", i < 5 / 2 + 1 ? "more" : "less") + .startArray(MULTI_VALUED_FIELD_NAME) + .value("val" + i) + .value("val" + (i + 1)) + .endArray() + .endObject() + ) + ); + } + + getMultiSortDocs(builders); + + assertAcked( + client().admin() + .indices() + .prepareCreate("high_card_idx") + .setMapping(SINGLE_VALUED_FIELD_NAME, "type=keyword", MULTI_VALUED_FIELD_NAME, "type=keyword", "tag", "type=keyword") + .get() + ); + for (int i = 0; i < 100; i++) { + builders.add( + client().prepareIndex("high_card_idx") + .setSource( + jsonBuilder().startObject() + .field(SINGLE_VALUED_FIELD_NAME, "val" + Strings.padStart(i + "", 3, '0')) + .startArray(MULTI_VALUED_FIELD_NAME) + .value("val" + Strings.padStart(i + "", 3, '0')) + .value("val" + Strings.padStart((i + 1) + "", 3, '0')) + .endArray() + .endObject() + ) + ); + } + prepareCreate("empty_bucket_idx").setMapping(SINGLE_VALUED_FIELD_NAME, "type=integer").get(); + + for (int i = 0; i < 2; i++) { + builders.add( + client().prepareIndex("empty_bucket_idx") + .setId("" + i) + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, i * 2).endObject()) + ); + } + indexRandom(true, builders); + createIndex("idx_unmapped"); + ensureSearchable(); + } + + private void getMultiSortDocs(List builders) throws IOException { + expectedMultiSortBuckets = new HashMap<>(); + Map bucketProps = new HashMap<>(); + bucketProps.put("_term", "val1"); + bucketProps.put("_count", 3L); + bucketProps.put("avg_l", 1d); + bucketProps.put("sum_d", 6d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + bucketProps = new HashMap<>(); + bucketProps.put("_term", "val2"); + bucketProps.put("_count", 3L); + bucketProps.put("avg_l", 2d); + bucketProps.put("sum_d", 6d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + bucketProps = new HashMap<>(); + bucketProps.put("_term", "val3"); + 
bucketProps.put("_count", 2L); + bucketProps.put("avg_l", 3d); + bucketProps.put("sum_d", 3d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + bucketProps = new HashMap<>(); + bucketProps.put("_term", "val4"); + bucketProps.put("_count", 2L); + bucketProps.put("avg_l", 3d); + bucketProps.put("sum_d", 4d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + bucketProps = new HashMap<>(); + bucketProps.put("_term", "val5"); + bucketProps.put("_count", 2L); + bucketProps.put("avg_l", 5d); + bucketProps.put("sum_d", 3d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + bucketProps = new HashMap<>(); + bucketProps.put("_term", "val6"); + bucketProps.put("_count", 1L); + bucketProps.put("avg_l", 5d); + bucketProps.put("sum_d", 1d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + bucketProps = new HashMap<>(); + bucketProps.put("_term", "val7"); + bucketProps.put("_count", 1L); + bucketProps.put("avg_l", 5d); + bucketProps.put("sum_d", 1d); + expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); + + assertAcked( + client().admin() + .indices() + .prepareCreate("sort_idx") + .setMapping(SINGLE_VALUED_FIELD_NAME, "type=keyword", MULTI_VALUED_FIELD_NAME, "type=keyword", "tag", "type=keyword") + .get() + ); + for (int i = 1; i <= 3; i++) { + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val1").field("l", 1).field("d", i).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val2").field("l", 2).field("d", i).endObject()) + ); + } + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val3").field("l", 3).field("d", 1).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + 
.setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val3").field("l", 3).field("d", 2).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val4").field("l", 3).field("d", 1).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val4").field("l", 3).field("d", 3).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val5").field("l", 5).field("d", 1).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val5").field("l", 5).field("d", 2).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val6").field("l", 5).field("d", 1).endObject()) + ); + builders.add( + client().prepareIndex("sort_idx") + .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val7").field("l", 5).field("d", 1).endObject()) + ); + } + + protected String key(Terms.Bucket bucket) { + return bucket.getKeyAsString(); + } +} diff --git a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/StringTermsIT.java b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/StringTermsIT.java index 3190bcb72fc..64f81cdcdec 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/StringTermsIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/terms/StringTermsIT.java @@ -32,25 +32,19 @@ package org.opensearch.search.aggregations.bucket.terms; import org.opensearch.OpenSearchException; -import org.opensearch.action.index.IndexRequestBuilder; import 
org.opensearch.action.search.SearchPhaseExecutionException; import org.opensearch.action.search.SearchResponse; -import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentParseException; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.plugins.Plugin; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; import org.opensearch.search.aggregations.AggregationExecutionException; -import org.opensearch.search.aggregations.AggregationTestScriptsPlugin; import org.opensearch.search.aggregations.Aggregator.SubAggCollectionMode; import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.bucket.AbstractTermsTestCase; import org.opensearch.search.aggregations.bucket.filter.Filter; import org.opensearch.search.aggregations.bucket.terms.Terms.Bucket; import org.opensearch.search.aggregations.metrics.Avg; @@ -60,23 +54,13 @@ import org.opensearch.search.aggregations.metrics.Sum; import org.opensearch.search.aggregations.support.ValueType; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; -import org.junit.After; -import org.junit.Before; -import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; -import java.util.List; -import java.util.Map; import java.util.Set; -import java.util.function.Function; -import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.index.query.QueryBuilders.termQuery; import static 
org.opensearch.search.aggregations.AggregationBuilders.avg; import static org.opensearch.search.aggregations.AggregationBuilders.extendedStats; @@ -93,228 +77,7 @@ import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.core.IsNull.notNullValue; @OpenSearchIntegTestCase.SuiteScopeTestCase -public class StringTermsIT extends AbstractTermsTestCase { - - private static final String SINGLE_VALUED_FIELD_NAME = "s_value"; - private static final String MULTI_VALUED_FIELD_NAME = "s_values"; - private static Map> expectedMultiSortBuckets; - - @Override - protected Collection> nodePlugins() { - return Collections.singleton(CustomScriptPlugin.class); - } - - @Before - public void randomizeOptimizations() { - TermsAggregatorFactory.COLLECT_SEGMENT_ORDS = randomBoolean(); - TermsAggregatorFactory.REMAP_GLOBAL_ORDS = randomBoolean(); - } - - @After - public void resetOptimizations() { - TermsAggregatorFactory.COLLECT_SEGMENT_ORDS = null; - TermsAggregatorFactory.REMAP_GLOBAL_ORDS = null; - } - - public static class CustomScriptPlugin extends AggregationTestScriptsPlugin { - - @Override - protected Map, Object>> pluginScripts() { - Map, Object>> scripts = super.pluginScripts(); - - scripts.put("'foo_' + _value", vars -> "foo_" + (String) vars.get("_value")); - scripts.put("_value.substring(0,3)", vars -> ((String) vars.get("_value")).substring(0, 3)); - - scripts.put("doc['" + MULTI_VALUED_FIELD_NAME + "']", vars -> { - Map doc = (Map) vars.get("doc"); - return doc.get(MULTI_VALUED_FIELD_NAME); - }); - - scripts.put("doc['" + SINGLE_VALUED_FIELD_NAME + "'].value", vars -> { - Map doc = (Map) vars.get("doc"); - ScriptDocValues.Strings value = (ScriptDocValues.Strings) doc.get(SINGLE_VALUED_FIELD_NAME); - return value.getValue(); - }); - - scripts.put("42", vars -> 42); - - return scripts; - } - - @Override - protected Map, Object>> nonDeterministicPluginScripts() { - Map, Object>> scripts = new HashMap<>(); - - scripts.put("Math.random()", vars -> 
StringTermsIT.randomDouble()); - - return scripts; - } - } - - @Override - public void setupSuiteScopeCluster() throws Exception { - assertAcked( - client().admin() - .indices() - .prepareCreate("idx") - .setMapping(SINGLE_VALUED_FIELD_NAME, "type=keyword", MULTI_VALUED_FIELD_NAME, "type=keyword", "tag", "type=keyword") - .get() - ); - List builders = new ArrayList<>(); - for (int i = 0; i < 5; i++) { - builders.add( - client().prepareIndex("idx") - .setSource( - jsonBuilder().startObject() - .field(SINGLE_VALUED_FIELD_NAME, "val" + i) - .field("i", i) - .field("constant", 1) - .field("tag", i < 5 / 2 + 1 ? "more" : "less") - .startArray(MULTI_VALUED_FIELD_NAME) - .value("val" + i) - .value("val" + (i + 1)) - .endArray() - .endObject() - ) - ); - } - - getMultiSortDocs(builders); - - assertAcked( - client().admin() - .indices() - .prepareCreate("high_card_idx") - .setMapping(SINGLE_VALUED_FIELD_NAME, "type=keyword", MULTI_VALUED_FIELD_NAME, "type=keyword", "tag", "type=keyword") - .get() - ); - for (int i = 0; i < 100; i++) { - builders.add( - client().prepareIndex("high_card_idx") - .setSource( - jsonBuilder().startObject() - .field(SINGLE_VALUED_FIELD_NAME, "val" + Strings.padStart(i + "", 3, '0')) - .startArray(MULTI_VALUED_FIELD_NAME) - .value("val" + Strings.padStart(i + "", 3, '0')) - .value("val" + Strings.padStart((i + 1) + "", 3, '0')) - .endArray() - .endObject() - ) - ); - } - prepareCreate("empty_bucket_idx").setMapping(SINGLE_VALUED_FIELD_NAME, "type=integer").get(); - - for (int i = 0; i < 2; i++) { - builders.add( - client().prepareIndex("empty_bucket_idx") - .setId("" + i) - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, i * 2).endObject()) - ); - } - indexRandom(true, builders); - createIndex("idx_unmapped"); - ensureSearchable(); - } - - private void getMultiSortDocs(List builders) throws IOException { - expectedMultiSortBuckets = new HashMap<>(); - Map bucketProps = new HashMap<>(); - bucketProps.put("_term", "val1"); - 
bucketProps.put("_count", 3L); - bucketProps.put("avg_l", 1d); - bucketProps.put("sum_d", 6d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - bucketProps = new HashMap<>(); - bucketProps.put("_term", "val2"); - bucketProps.put("_count", 3L); - bucketProps.put("avg_l", 2d); - bucketProps.put("sum_d", 6d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - bucketProps = new HashMap<>(); - bucketProps.put("_term", "val3"); - bucketProps.put("_count", 2L); - bucketProps.put("avg_l", 3d); - bucketProps.put("sum_d", 3d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - bucketProps = new HashMap<>(); - bucketProps.put("_term", "val4"); - bucketProps.put("_count", 2L); - bucketProps.put("avg_l", 3d); - bucketProps.put("sum_d", 4d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - bucketProps = new HashMap<>(); - bucketProps.put("_term", "val5"); - bucketProps.put("_count", 2L); - bucketProps.put("avg_l", 5d); - bucketProps.put("sum_d", 3d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - bucketProps = new HashMap<>(); - bucketProps.put("_term", "val6"); - bucketProps.put("_count", 1L); - bucketProps.put("avg_l", 5d); - bucketProps.put("sum_d", 1d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - bucketProps = new HashMap<>(); - bucketProps.put("_term", "val7"); - bucketProps.put("_count", 1L); - bucketProps.put("avg_l", 5d); - bucketProps.put("sum_d", 1d); - expectedMultiSortBuckets.put((String) bucketProps.get("_term"), bucketProps); - - assertAcked( - client().admin() - .indices() - .prepareCreate("sort_idx") - .setMapping(SINGLE_VALUED_FIELD_NAME, "type=keyword", MULTI_VALUED_FIELD_NAME, "type=keyword", "tag", "type=keyword") - .get() - ); - for (int i = 1; i <= 3; i++) { - builders.add( - client().prepareIndex("sort_idx") - 
.setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val1").field("l", 1).field("d", i).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val2").field("l", 2).field("d", i).endObject()) - ); - } - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val3").field("l", 3).field("d", 1).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val3").field("l", 3).field("d", 2).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val4").field("l", 3).field("d", 1).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val4").field("l", 3).field("d", 3).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val5").field("l", 5).field("d", 1).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val5").field("l", 5).field("d", 2).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val6").field("l", 5).field("d", 1).endObject()) - ); - builders.add( - client().prepareIndex("sort_idx") - .setSource(jsonBuilder().startObject().field(SINGLE_VALUED_FIELD_NAME, "val7").field("l", 5).field("d", 1).endObject()) - ); - } - - private String key(Terms.Bucket bucket) { - return bucket.getKeyAsString(); - } +public class StringTermsIT extends BaseStringTermsTestCase { // the main purpose of this test is to make sure we're not allocating 2GB of memory per shard public void 
testSizeIsZero() { diff --git a/server/src/main/java/org/opensearch/search/SearchModule.java b/server/src/main/java/org/opensearch/search/SearchModule.java index dc5309b50ab..bf0cc646d27 100644 --- a/server/src/main/java/org/opensearch/search/SearchModule.java +++ b/server/src/main/java/org/opensearch/search/SearchModule.java @@ -159,8 +159,11 @@ import org.opensearch.search.aggregations.bucket.sampler.InternalSampler; import org.opensearch.search.aggregations.bucket.sampler.SamplerAggregationBuilder; import org.opensearch.search.aggregations.bucket.sampler.UnmappedSampler; import org.opensearch.search.aggregations.bucket.terms.DoubleTerms; +import org.opensearch.search.aggregations.bucket.terms.InternalMultiTerms; import org.opensearch.search.aggregations.bucket.terms.LongRareTerms; import org.opensearch.search.aggregations.bucket.terms.LongTerms; +import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationFactory; import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.SignificantLongTerms; import org.opensearch.search.aggregations.bucket.terms.SignificantStringTerms; @@ -687,6 +690,12 @@ public class SearchModule { .setAggregatorRegistrar(CompositeAggregationBuilder::registerAggregators), builder ); + registerAggregation( + new AggregationSpec(MultiTermsAggregationBuilder.NAME, MultiTermsAggregationBuilder::new, MultiTermsAggregationBuilder.PARSER) + .addResultReader(InternalMultiTerms::new) + .setAggregatorRegistrar(MultiTermsAggregationFactory::registerAggregators), + builder + ); registerFromPlugin(plugins, SearchPlugin::getAggregations, (agg) -> this.registerAggregation(agg, builder)); // after aggs have been registered, see if there are any new VSTypes that need to be linked to core fields diff --git 
a/server/src/main/java/org/opensearch/search/aggregations/AggregationBuilders.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationBuilders.java index 99a1107675e..69a9fd92ac4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregationBuilders.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationBuilders.java @@ -66,6 +66,7 @@ import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; import org.opensearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder; import org.opensearch.search.aggregations.bucket.sampler.Sampler; import org.opensearch.search.aggregations.bucket.sampler.SamplerAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.SignificantTerms; import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder; @@ -388,4 +389,11 @@ public class AggregationBuilders { public static CompositeAggregationBuilder composite(String name, List> sources) { return new CompositeAggregationBuilder(name, sources); } + + /** + * Create a new {@link MultiTermsAggregationBuilder} aggregation with the given name. 
+ */ + public static MultiTermsAggregationBuilder multiTerms(String name) { + return new MultiTermsAggregationBuilder(name); + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java new file mode 100644 index 00000000000..fd1758d3ea8 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java @@ -0,0 +1,440 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.AggregationExecutionException; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.KeyComparable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Result of the {@link MultiTermsAggregator}. + */ +public class InternalMultiTerms extends InternalTerms { + /** + * Internal Multi Terms Bucket. + */ + public static class Bucket extends InternalTerms.AbstractInternalBucket implements KeyComparable { + + protected long bucketOrd; + /** + * list of terms values. 
+ */ + protected List termValues; + protected long docCount; + protected InternalAggregations aggregations; + protected boolean showDocCountError; + protected long docCountError; + /** + * A list of term's {@link DocValueFormat}. + */ + protected final List termFormats; + + private static final String PIPE = "|"; + + /** + * Create default {@link Bucket}. + */ + public static Bucket EMPTY(boolean showTermDocCountError, List formats) { + return new Bucket(null, 0, null, showTermDocCountError, 0, formats); + } + + public Bucket( + List values, + long docCount, + InternalAggregations aggregations, + boolean showDocCountError, + long docCountError, + List formats + ) { + this.termValues = values; + this.docCount = docCount; + this.aggregations = aggregations; + this.showDocCountError = showDocCountError; + this.docCountError = docCountError; + this.termFormats = formats; + } + + public Bucket(StreamInput in, List formats, boolean showDocCountError) throws IOException { + this.termValues = in.readList(StreamInput::readGenericValue); + this.docCount = in.readVLong(); + this.aggregations = InternalAggregations.readFrom(in); + this.showDocCountError = showDocCountError; + this.docCountError = -1; + if (showDocCountError) { + this.docCountError = in.readLong(); + } + this.termFormats = formats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonFields.KEY.getPreferredName(), getKey()); + builder.field(CommonFields.KEY_AS_STRING.getPreferredName(), getKeyAsString()); + builder.field(CommonFields.DOC_COUNT.getPreferredName(), getDocCount()); + if (showDocCountError) { + builder.field(DOC_COUNT_ERROR_UPPER_BOUND_FIELD_NAME.getPreferredName(), getDocCountError()); + } + aggregations.toXContentInternal(builder, params); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
out.writeCollection(termValues, StreamOutput::writeGenericValue); + out.writeVLong(docCount); + aggregations.writeTo(out); + if (showDocCountError) { + out.writeLong(docCountError); + } + } + + @Override + public List getKey() { + List keys = new ArrayList<>(termValues.size()); + for (int i = 0; i < termValues.size(); i++) { + keys.add(formatObject(termValues.get(i), termFormats.get(i))); + } + return keys; + } + + @Override + public String getKeyAsString() { + return getKey().stream().map(Object::toString).collect(Collectors.joining(PIPE)); + } + + @Override + public long getDocCount() { + return docCount; + } + + @Override + public Aggregations getAggregations() { + return aggregations; + } + + @Override + void setDocCountError(long docCountError) { + this.docCountError = docCountError; + } + + @Override + public void setDocCountError(Function updater) { + this.docCountError = updater.apply(this.docCountError); + } + + @Override + public boolean showDocCountError() { + return showDocCountError; + } + + @Override + public Number getKeyAsNumber() { + throw new IllegalArgumentException("getKeyAsNumber is not supported by [" + MultiTermsAggregationBuilder.NAME + "]"); + } + + @Override + public long getDocCountError() { + if (!showDocCountError) { + throw new IllegalStateException("show_terms_doc_count_error is false"); + } + return docCountError; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Bucket other = (Bucket) obj; + if (showDocCountError && docCountError != other.docCountError) { + return false; + } + return termValues.equals(other.termValues) + && docCount == other.docCount + && aggregations.equals(other.aggregations) + && showDocCountError == other.showDocCountError; + } + + @Override + public int hashCode() { + return Objects.hash(termValues, docCount, aggregations, showDocCountError, showDocCountError ? 
docCountError : 0); + } + + @Override + public int compareKey(Bucket other) { + return new BucketComparator().compare(this.termValues, other.termValues); + } + + /** + * Visible for testing. + */ + protected static class BucketComparator implements Comparator> { + @SuppressWarnings({ "unchecked" }) + @Override + public int compare(List thisObjects, List thatObjects) { + if (thisObjects.size() != thatObjects.size()) { + throw new AggregationExecutionException( + "[" + MultiTermsAggregationBuilder.NAME + "] aggregations failed due to terms" + " size is different" + ); + } + for (int i = 0; i < thisObjects.size(); i++) { + final Object thisObject = thisObjects.get(i); + final Object thatObject = thatObjects.get(i); + int ret = ((Comparable) thisObject).compareTo(thatObject); + if (ret != 0) { + return ret; + } + } + return 0; + } + } + } + + private final int shardSize; + private final boolean showTermDocCountError; + private final long otherDocCount; + private final List termFormats; + private final List buckets; + private Map bucketMap; + + private long docCountError; + + public InternalMultiTerms( + String name, + BucketOrder reduceOrder, + BucketOrder order, + int requiredSize, + long minDocCount, + Map metadata, + int shardSize, + boolean showTermDocCountError, + long otherDocCount, + long docCountError, + List formats, + List buckets + ) { + super(name, reduceOrder, order, requiredSize, minDocCount, metadata); + this.shardSize = shardSize; + this.showTermDocCountError = showTermDocCountError; + this.otherDocCount = otherDocCount; + this.termFormats = formats; + this.buckets = buckets; + this.docCountError = docCountError; + } + + public InternalMultiTerms(StreamInput in) throws IOException { + super(in); + this.docCountError = in.readZLong(); + this.termFormats = in.readList(stream -> stream.readNamedWriteable(DocValueFormat.class)); + this.shardSize = readSize(in); + this.showTermDocCountError = in.readBoolean(); + this.otherDocCount = in.readVLong(); + 
this.buckets = in.readList(steam -> new Bucket(steam, termFormats, showTermDocCountError)); + } + + @Override + public String getWriteableName() { + return MultiTermsAggregationBuilder.NAME; + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + return doXContentCommon(builder, params, docCountError, otherDocCount, buckets); + } + + @Override + public InternalMultiTerms create(List buckets) { + return new InternalMultiTerms( + name, + reduceOrder, + order, + requiredSize, + minDocCount, + metadata, + shardSize, + showTermDocCountError, + otherDocCount, + docCountError, + termFormats, + buckets + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + return new Bucket( + prototype.termValues, + prototype.docCount, + aggregations, + prototype.showDocCountError, + prototype.docCountError, + prototype.termFormats + ); + } + + @Override + protected void writeTermTypeInfoTo(StreamOutput out) throws IOException { + out.writeZLong(docCountError); + out.writeCollection(termFormats, StreamOutput::writeNamedWriteable); + writeSize(shardSize, out); + out.writeBoolean(showTermDocCountError); + out.writeVLong(otherDocCount); + out.writeList(buckets); + } + + @Override + public List getBuckets() { + return buckets; + } + + @Override + public Bucket getBucketByKey(String term) { + if (bucketMap == null) { + bucketMap = buckets.stream().collect(Collectors.toMap(InternalMultiTerms.Bucket::getKeyAsString, Function.identity())); + } + return bucketMap.get(term); + } + + @Override + public long getDocCountError() { + return docCountError; + } + + @Override + public long getSumOfOtherDocCounts() { + return otherDocCount; + } + + @Override + protected void setDocCountError(long docCountError) { + this.docCountError = docCountError; + } + + @Override + protected int getShardSize() { + return shardSize; + } + + @Override + protected InternalMultiTerms create( + String name, 
+ List buckets, + BucketOrder reduceOrder, + long docCountError, + long otherDocCount + ) { + return new InternalMultiTerms( + name, + reduceOrder, + order, + requiredSize, + minDocCount, + metadata, + shardSize, + showTermDocCountError, + otherDocCount, + docCountError, + termFormats, + buckets + ); + } + + @Override + protected Bucket[] createBucketsArray(int size) { + return new Bucket[size]; + } + + @Override + Bucket createBucket(long docCount, InternalAggregations aggs, long docCountError, Bucket prototype) { + return new Bucket( + prototype.termValues, + docCount, + aggs, + prototype.showDocCountError, + prototype.docCountError, + prototype.termFormats + ); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + if (super.equals(obj) == false) return false; + InternalMultiTerms that = (InternalMultiTerms) obj; + + if (showTermDocCountError && docCountError != that.docCountError) { + return false; + } + return Objects.equals(buckets, that.buckets) + && Objects.equals(otherDocCount, that.otherDocCount) + && Objects.equals(showTermDocCountError, that.showTermDocCountError) + && Objects.equals(shardSize, that.shardSize) + && Objects.equals(docCountError, that.docCountError); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), buckets, otherDocCount, showTermDocCountError, shardSize); + } + + /** + * Copy from InternalComposite + * + * Format obj using the provided {@link DocValueFormat}. + * If the format is equals to {@link DocValueFormat#RAW}, the object is returned as is + * for numbers and a string for {@link BytesRef}s. 
+ */ + static Object formatObject(Object obj, DocValueFormat format) { + if (obj == null) { + return null; + } + if (obj.getClass() == BytesRef.class) { + BytesRef value = (BytesRef) obj; + if (format == DocValueFormat.RAW) { + return value.utf8ToString(); + } else { + return format.format(value); + } + } else if (obj.getClass() == Long.class) { + long value = (long) obj; + if (format == DocValueFormat.RAW) { + return value; + } else { + return format.format(value); + } + } else if (obj.getClass() == Double.class) { + double value = (double) obj; + if (format == DocValueFormat.RAW) { + return value; + } else { + return format.format(value); + } + } + return obj; + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java index be397bcbb2f..8fae5720a90 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java @@ -57,11 +57,12 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import static org.opensearch.search.aggregations.InternalOrder.isKeyAsc; import static org.opensearch.search.aggregations.InternalOrder.isKeyOrder; -public abstract class InternalTerms, B extends InternalTerms.Bucket> extends +public abstract class InternalTerms, B extends InternalTerms.AbstractInternalBucket> extends InternalMultiBucketAggregation implements Terms { @@ -69,10 +70,15 @@ public abstract class InternalTerms, B extends Int protected static final ParseField DOC_COUNT_ERROR_UPPER_BOUND_FIELD_NAME = new ParseField("doc_count_error_upper_bound"); protected static final ParseField SUM_OF_OTHER_DOC_COUNTS = new ParseField("sum_other_doc_count"); - public abstract static class Bucket> extends 
InternalMultiBucketAggregation.InternalBucket - implements - Terms.Bucket, - KeyComparable { + public abstract static class AbstractInternalBucket extends InternalMultiBucketAggregation.InternalBucket implements Terms.Bucket { + abstract void setDocCountError(long docCountError); + + abstract void setDocCountError(Function updater); + + abstract boolean showDocCountError(); + } + + public abstract static class Bucket> extends AbstractInternalBucket implements KeyComparable { /** * Reads a bucket. Should be a constructor reference. */ @@ -142,6 +148,21 @@ public abstract class InternalTerms, B extends Int return docCountError; } + @Override + public void setDocCountError(long docCountError) { + this.docCountError = docCountError; + } + + @Override + public void setDocCountError(Function updater) { + this.docCountError = updater.apply(this.docCountError); + } + + @Override + public boolean showDocCountError() { + return showDocCountError; + } + @Override public Aggregations getAggregations() { return aggregations; @@ -274,7 +295,7 @@ public abstract class InternalTerms, B extends Int } else { // otherwise use the doc count of the last term in the // aggregation - return terms.getBuckets().stream().mapToLong(Bucket::getDocCount).min().getAsLong(); + return terms.getBuckets().stream().mapToLong(MultiBucketsAggregation.Bucket::getDocCount).min().getAsLong(); } } else { return -1; @@ -393,7 +414,7 @@ public abstract class InternalTerms, B extends Int // for the existing error calculated in a previous reduce. // Note that if the error is unbounded (-1) this will be fixed // later in this method. 
- bucket.docCountError -= thisAggDocCountError; + bucket.setDocCountError(docCountError -> docCountError - thisAggDocCountError); } } @@ -419,11 +440,12 @@ public abstract class InternalTerms, B extends Int final BucketPriorityQueue ordered = new BucketPriorityQueue<>(size, order.comparator()); for (B bucket : reducedBuckets) { if (sumDocCountError == -1) { - bucket.docCountError = -1; + bucket.setDocCountError(-1); } else { - bucket.docCountError += sumDocCountError; + final long finalSumDocCountError = sumDocCountError; + bucket.setDocCountError(docCountError -> docCountError + finalSumDocCountError); } - if (bucket.docCount >= minDocCount) { + if (bucket.getDocCount() >= minDocCount) { B removed = ordered.insertWithOverflow(bucket); if (removed != null) { otherDocCount += removed.getDocCount(); @@ -448,9 +470,10 @@ public abstract class InternalTerms, B extends Int reduceContext.consumeBucketsAndMaybeBreak(1); list[i] = reducedBuckets.get(i); if (sumDocCountError == -1) { - list[i].docCountError = -1; + list[i].setDocCountError(-1); } else { - list[i].docCountError += sumDocCountError; + final long fSumDocCountError = sumDocCountError; + list[i].setDocCountError(docCountError -> docCountError + fSumDocCountError); } } } @@ -474,15 +497,15 @@ public abstract class InternalTerms, B extends Int long docCountError = 0; List aggregationsList = new ArrayList<>(buckets.size()); for (B bucket : buckets) { - docCount += bucket.docCount; + docCount += bucket.getDocCount(); if (docCountError != -1) { - if (bucket.docCountError == -1) { + if (bucket.showDocCountError() == false || bucket.getDocCountError() == -1) { docCountError = -1; } else { - docCountError += bucket.docCountError; + docCountError += bucket.getDocCountError(); } } - aggregationsList.add(bucket.aggregations); + aggregationsList.add((InternalAggregations) bucket.getAggregations()); } InternalAggregations aggs = InternalAggregations.reduce(aggregationsList, context); return createBucket(docCount, aggs, 
docCountError, buckets.get(0)); @@ -524,12 +547,12 @@ public abstract class InternalTerms, B extends Int Params params, long docCountError, long otherDocCount, - List buckets + List buckets ) throws IOException { builder.field(DOC_COUNT_ERROR_UPPER_BOUND_FIELD_NAME.getPreferredName(), docCountError); builder.field(SUM_OF_OTHER_DOC_COUNTS.getPreferredName(), otherDocCount); builder.startArray(CommonFields.BUCKETS.getPreferredName()); - for (Bucket bucket : buckets) { + for (AbstractInternalBucket bucket : buckets) { bucket.toXContent(builder, params); } builder.endArray(); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilder.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilder.java new file mode 100644 index 00000000000..78be4f980bc --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilder.java @@ -0,0 +1,443 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.opensearch.common.ParseField; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.ObjectParser; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.search.aggregations.AbstractAggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactory; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValuesSourceRegistry; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS; + +/** + * Multi-terms aggregation supports collecting terms from multiple fields in the same document. + * + *

+ * For example, using the multi-terms aggregation to group by two fields region, host, calculate max cpu, and sort by max cpu. + *

+ *
+ *   GET test_000001/_search
+ *   {
+ *     "size": 0,
+ *     "aggs": {
+ *       "hot": {
+ *         "multi_terms": {
+ *           "terms": [{
+ *             "field": "region"
+ *           },{
+ *             "field": "host"
+ *           }],
+ *           "order": {"max-cpu": "desc"}
+ *         },
+ *         "aggs": {
+ *           "max-cpu": { "max": { "field": "cpu" } }
+ *         }
+ *       }
+ *     }
+ *   }
+ * 
+ * + *

+ * The aggregation result contains + - key: a list of values extracted from multiple fields in the same doc. + *

+ *
+ *   {
+ *     "hot": {
+ *       "doc_count_error_upper_bound": 0,
+ *       "sum_other_doc_count": 0,
+ *       "buckets": [
+ *         {
+ *           "key": [
+ *             "dub",
+ *             "h1"
+ *           ],
+ *           "key_as_string": "dub|h1",
+ *           "doc_count": 2,
+ *           "max-cpu": {
+ *             "value": 90.0
+ *           }
+ *         },
+ *         {
+ *           "key": [
+ *             "dub",
+ *             "h2"
+ *           ],
+ *           "key_as_string": "dub|h2",
+ *           "doc_count": 2,
+ *           "max-cpu": {
+ *             "value": 70.0
+ *           }
+ *         }
+ *       ]
+ *     }
+ *   }
+ * 
+ * + *

+ * Notes: The current implementation focuses on adding new type aggregates. Performance (latency) is not good, mainly because of + simply encoding/decoding a list of values as bucket keys. + *

+ */ +public class MultiTermsAggregationBuilder extends AbstractAggregationBuilder { + public static final String NAME = "multi_terms"; + public static final ObjectParser PARSER = ObjectParser.fromBuilder( + NAME, + MultiTermsAggregationBuilder::new + ); + + public static final ParseField TERMS_FIELD = new ParseField("terms"); + public static final ParseField SHARD_SIZE_FIELD_NAME = new ParseField("shard_size"); + public static final ParseField MIN_DOC_COUNT_FIELD_NAME = new ParseField("min_doc_count"); + public static final ParseField SHARD_MIN_DOC_COUNT_FIELD_NAME = new ParseField("shard_min_doc_count"); + public static final ParseField REQUIRED_SIZE_FIELD_NAME = new ParseField("size"); + public static final ParseField SHOW_TERM_DOC_COUNT_ERROR = new ParseField("show_term_doc_count_error"); + public static final ParseField ORDER_FIELD = new ParseField("order"); + + @Override + public String getType() { + return NAME; + } + + static { + final ObjectParser parser = MultiTermsValuesSourceConfig.PARSER.apply( + true, + true, + true, + true + ); + PARSER.declareObjectArray(MultiTermsAggregationBuilder::terms, (p, c) -> parser.parse(p, null).build(), TERMS_FIELD); + + PARSER.declareBoolean(MultiTermsAggregationBuilder::showTermDocCountError, SHOW_TERM_DOC_COUNT_ERROR); + + PARSER.declareInt(MultiTermsAggregationBuilder::shardSize, SHARD_SIZE_FIELD_NAME); + + PARSER.declareLong(MultiTermsAggregationBuilder::minDocCount, MIN_DOC_COUNT_FIELD_NAME); + + PARSER.declareLong(MultiTermsAggregationBuilder::shardMinDocCount, SHARD_MIN_DOC_COUNT_FIELD_NAME); + + PARSER.declareInt(MultiTermsAggregationBuilder::size, REQUIRED_SIZE_FIELD_NAME); + + PARSER.declareObjectArray(MultiTermsAggregationBuilder::order, (p, c) -> InternalOrder.Parser.parseOrderParam(p), ORDER_FIELD); + + PARSER.declareField( + MultiTermsAggregationBuilder::collectMode, + (p, c) -> Aggregator.SubAggCollectionMode.parse(p.text(), LoggingDeprecationHandler.INSTANCE), + Aggregator.SubAggCollectionMode.KEY, + 
ObjectParser.ValueType.STRING + ); + } + + public static final ValuesSourceRegistry.RegistryKey REGISTRY_KEY = + new ValuesSourceRegistry.RegistryKey<>( + MultiTermsAggregationBuilder.NAME, + MultiTermsAggregationFactory.InternalValuesSourceSupplier.class + ); + + private List terms; + + private BucketOrder order = BucketOrder.compound(BucketOrder.count(false)); // automatically adds tie-breaker key asc order + private Aggregator.SubAggCollectionMode collectMode = null; + private TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds( + DEFAULT_BUCKET_COUNT_THRESHOLDS + ); + private boolean showTermDocCountError = false; + + public MultiTermsAggregationBuilder(String name) { + super(name); + } + + protected MultiTermsAggregationBuilder( + MultiTermsAggregationBuilder clone, + AggregatorFactories.Builder factoriesBuilder, + Map metadata + ) { + super(clone, factoriesBuilder, metadata); + this.terms = new ArrayList<>(clone.terms); + this.order = clone.order; + this.collectMode = clone.collectMode; + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); + this.showTermDocCountError = clone.showTermDocCountError; + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new MultiTermsAggregationBuilder(this, factoriesBuilder, metadata); + } + + /** + * Read from a stream. 
+ */ + public MultiTermsAggregationBuilder(StreamInput in) throws IOException { + super(in); + terms = in.readList(MultiTermsValuesSourceConfig::new); + bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); + collectMode = in.readOptionalWriteable(Aggregator.SubAggCollectionMode::readFromStream); + order = InternalOrder.Streams.readOrder(in); + showTermDocCountError = in.readBoolean(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeList(terms); + bucketCountThresholds.writeTo(out); + out.writeOptionalWriteable(collectMode); + order.writeTo(out); + out.writeBoolean(showTermDocCountError); + } + + @Override + protected AggregatorFactory doBuild( + QueryShardContext queryShardContext, + AggregatorFactory parent, + AggregatorFactories.Builder subfactoriesBuilder + ) throws IOException { + return new MultiTermsAggregationFactory( + name, + queryShardContext, + parent, + subfactoriesBuilder, + metadata, + terms, + order, + collectMode, + bucketCountThresholds, + showTermDocCountError + ); + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (terms != null) { + builder.field(TERMS_FIELD.getPreferredName(), terms); + } + bucketCountThresholds.toXContent(builder, params); + builder.field(SHOW_TERM_DOC_COUNT_ERROR.getPreferredName(), showTermDocCountError); + builder.field(ORDER_FIELD.getPreferredName()); + order.toXContent(builder, params); + if (collectMode != null) { + builder.field(Aggregator.SubAggCollectionMode.KEY.getPreferredName(), collectMode.parseField().getPreferredName()); + } + builder.endObject(); + return builder; + } + + /** + * Set the terms. + */ + public MultiTermsAggregationBuilder terms(List terms) { + if (terms == null) { + throw new IllegalArgumentException("[terms] must not be null. 
Found null terms in [" + name + "]"); + } + if (terms.size() < 2) { + throw new IllegalArgumentException( + "multi term aggregation must has at least 2 terms. Found [" + + terms.size() + + "] in" + + " [" + + name + + "]" + + (terms.size() == 1 ? " Use terms aggregation for single term aggregation" : "") + ); + } + this.terms = terms; + return this; + } + + /** + * Sets the size - indicating how many term buckets should be returned + * (defaults to 10) + */ + public MultiTermsAggregationBuilder size(int size) { + if (size <= 0) { + throw new IllegalArgumentException("[size] must be greater than 0. Found [" + size + "] in [" + name + "]"); + } + bucketCountThresholds.setRequiredSize(size); + return this; + } + + /** + * Returns the number of term buckets currently configured + */ + public int size() { + return bucketCountThresholds.getRequiredSize(); + } + + /** + * Sets the shard_size - indicating the number of term buckets each shard + * will return to the coordinating node (the node that coordinates the + * search execution). The higher the shard size is, the more accurate the + * results are. + */ + public MultiTermsAggregationBuilder shardSize(int shardSize) { + if (shardSize <= 0) { + throw new IllegalArgumentException("[shardSize] must be greater than 0. Found [" + shardSize + "] in [" + name + "]"); + } + bucketCountThresholds.setShardSize(shardSize); + return this; + } + + /** + * Returns the number of term buckets per shard that are currently configured + */ + public int shardSize() { + return bucketCountThresholds.getShardSize(); + } + + /** + * Set the minimum document count terms should have in order to appear in + * the response. + */ + public MultiTermsAggregationBuilder minDocCount(long minDocCount) { + if (minDocCount < 0) { + throw new IllegalArgumentException( + "[minDocCount] must be greater than or equal to 0. 
Found [" + minDocCount + "] in [" + name + "]" + ); + } + bucketCountThresholds.setMinDocCount(minDocCount); + return this; + } + + /** + * Returns the minimum document count required per term + */ + public long minDocCount() { + return bucketCountThresholds.getMinDocCount(); + } + + /** + * Set the minimum document count terms should have on the shard in order to + * appear in the response. + */ + public MultiTermsAggregationBuilder shardMinDocCount(long shardMinDocCount) { + if (shardMinDocCount < 0) { + throw new IllegalArgumentException( + "[shardMinDocCount] must be greater than or equal to 0. Found [" + shardMinDocCount + "] in [" + name + "]" + ); + } + bucketCountThresholds.setShardMinDocCount(shardMinDocCount); + return this; + } + + /** + * Returns the minimum document count required per term, per shard + */ + public long shardMinDocCount() { + return bucketCountThresholds.getShardMinDocCount(); + } + + /** Set a new order on this builder and return the builder so that calls + * can be chained. A tie-breaker may be added to avoid non-deterministic ordering. */ + public MultiTermsAggregationBuilder order(BucketOrder order) { + if (order == null) { + throw new IllegalArgumentException("[order] must not be null: [" + name + "]"); + } + if (order instanceof InternalOrder.CompoundOrder || InternalOrder.isKeyOrder(order)) { + this.order = order; // if order already contains a tie-breaker we are good to go + } else { // otherwise add a tie-breaker by using a compound order + this.order = BucketOrder.compound(order); + } + return this; + } + + /** + * Sets the order in which the buckets will be returned. A tie-breaker may be added to avoid non-deterministic + * ordering. + */ + public MultiTermsAggregationBuilder order(List orders) { + if (orders == null) { + throw new IllegalArgumentException("[orders] must not be null: [" + name + "]"); + } + // if the list only contains one order use that to avoid inconsistent xcontent + order(orders.size() > 1 ? 
BucketOrder.compound(orders) : orders.get(0)); + return this; + } + + /** + * Gets the order in which the buckets will be returned. + */ + public BucketOrder order() { + return order; + } + + /** + * Expert: set the collection mode. + */ + public MultiTermsAggregationBuilder collectMode(Aggregator.SubAggCollectionMode collectMode) { + if (collectMode == null) { + throw new IllegalArgumentException("[collectMode] must not be null: [" + name + "]"); + } + this.collectMode = collectMode; + return this; + } + + /** + * Expert: get the collection mode. + */ + public Aggregator.SubAggCollectionMode collectMode() { + return collectMode; + } + + /** + * Get whether doc count error will be return for individual terms + */ + public boolean showTermDocCountError() { + return showTermDocCountError; + } + + /** + * Set whether doc count error will be return for individual terms + */ + public MultiTermsAggregationBuilder showTermDocCountError(boolean showTermDocCountError) { + this.showTermDocCountError = showTermDocCountError; + return this; + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.MANY; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), bucketCountThresholds, collectMode, order, showTermDocCountError); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + if (super.equals(obj) == false) return false; + MultiTermsAggregationBuilder other = (MultiTermsAggregationBuilder) obj; + return Objects.equals(terms, other.terms) + && Objects.equals(bucketCountThresholds, other.bucketCountThresholds) + && Objects.equals(collectMode, other.collectMode) + && Objects.equals(order, other.order) + && Objects.equals(showTermDocCountError, other.showTermDocCountError); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationFactory.java 
b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationFactory.java new file mode 100644 index 00000000000..d5600bc030b --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationFactory.java @@ -0,0 +1,163 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.opensearch.common.collect.Tuple; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactory; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.CardinalityUpperBound; +import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.bucket.BucketUtils; +import org.opensearch.search.aggregations.support.CoreValuesSourceType; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValuesSource; +import org.opensearch.search.aggregations.support.ValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValuesSourceRegistry; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder.REGISTRY_KEY; + +/** + * Factory of {@link MultiTermsAggregator}. 
+ */ +public class MultiTermsAggregationFactory extends AggregatorFactory { + + private final List> configs; + private final List formats; + /** + * Fields inherent from Terms Aggregation Factory. + */ + private final BucketOrder order; + private final Aggregator.SubAggCollectionMode collectMode; + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + private final boolean showTermDocCountError; + + public static void registerAggregators(ValuesSourceRegistry.Builder builder) { + builder.register( + REGISTRY_KEY, + org.opensearch.common.collect.List.of(CoreValuesSourceType.BYTES, CoreValuesSourceType.IP), + config -> { + final IncludeExclude.StringFilter filter = config.v2() == null + ? null + : config.v2().convertToStringFilter(config.v1().format()); + return MultiTermsAggregator.InternalValuesSourceFactory.bytesValuesSource(config.v1().getValuesSource(), filter); + }, + true + ); + + builder.register( + REGISTRY_KEY, + org.opensearch.common.collect.List.of(CoreValuesSourceType.NUMERIC, CoreValuesSourceType.BOOLEAN, CoreValuesSourceType.DATE), + config -> { + ValuesSourceConfig valuesSourceConfig = config.v1(); + IncludeExclude includeExclude = config.v2(); + ValuesSource.Numeric valuesSource = ((ValuesSource.Numeric) valuesSourceConfig.getValuesSource()); + IncludeExclude.LongFilter longFilter = null; + if (valuesSource.isFloatingPoint()) { + if (includeExclude != null) { + longFilter = includeExclude.convertToDoubleFilter(); + } + return MultiTermsAggregator.InternalValuesSourceFactory.doubleValueSource(valuesSource, longFilter); + } else { + if (includeExclude != null) { + longFilter = includeExclude.convertToLongFilter(valuesSourceConfig.format()); + } + return MultiTermsAggregator.InternalValuesSourceFactory.longValuesSource(valuesSource, longFilter); + } + }, + true + ); + + builder.registerUsage(MultiTermsAggregationBuilder.NAME); + } + + public MultiTermsAggregationFactory( + String name, + QueryShardContext queryShardContext, + 
AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metadata, + List multiTermConfigs, + BucketOrder order, + Aggregator.SubAggCollectionMode collectMode, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + boolean showTermDocCountError + ) throws IOException { + super(name, queryShardContext, parent, subFactoriesBuilder, metadata); + this.configs = multiTermConfigs.stream() + .map( + c -> new Tuple( + ValuesSourceConfig.resolveUnregistered( + queryShardContext, + c.getUserValueTypeHint(), + c.getFieldName(), + c.getScript(), + c.getMissing(), + c.getTimeZone(), + c.getFormat(), + CoreValuesSourceType.BYTES + ), + c.getIncludeExclude() + ) + ) + .collect(Collectors.toList()); + this.formats = this.configs.stream().map(c -> c.v1().format()).collect(Collectors.toList()); + this.order = order; + this.collectMode = collectMode; + this.bucketCountThresholds = bucketCountThresholds; + this.showTermDocCountError = showTermDocCountError; + } + + @Override + protected Aggregator createInternal( + SearchContext searchContext, + Aggregator parent, + CardinalityUpperBound cardinality, + Map metadata + ) throws IOException { + TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); + if (InternalOrder.isKeyOrder(order) == false + && bucketCountThresholds.getShardSize() == TermsAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { + // The user has not made a shardSize selection. 
Use default + // heuristic to avoid any wrong-ranking caused by distributed + // counting + bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize())); + } + bucketCountThresholds.ensureValidity(); + return new MultiTermsAggregator( + name, + factories, + showTermDocCountError, + configs.stream() + .map(config -> queryShardContext.getValuesSourceRegistry().getAggregator(REGISTRY_KEY, config.v1()).build(config)) + .collect(Collectors.toList()), + configs.stream().map(c -> c.v1().format()).collect(Collectors.toList()), + order, + collectMode, + bucketCountThresholds, + searchContext, + parent, + cardinality, + metadata + ); + } + + public interface InternalValuesSourceSupplier { + MultiTermsAggregator.InternalValuesSource build(Tuple config); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java new file mode 100644 index 00000000000..36bf710f743 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java @@ -0,0 +1,438 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.NumericUtils; +import org.apache.lucene.util.PriorityQueue; +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.CheckedSupplier; +import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.lease.Releasables; +import org.opensearch.index.fielddata.SortedBinaryDocValues; +import org.opensearch.index.fielddata.SortedNumericDoubleValues; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.CardinalityUpperBound; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator; +import org.opensearch.search.aggregations.support.AggregationPath; +import org.opensearch.search.aggregations.support.ValuesSource; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.search.aggregations.InternalOrder.isKeyOrder; +import static 
org.opensearch.search.aggregations.bucket.terms.TermsAggregator.descendsFromNestedAggregator; + +/** + * An aggregator that aggregate with multi_terms. + */ +public class MultiTermsAggregator extends DeferableBucketAggregator { + + private final BytesKeyedBucketOrds bucketOrds; + private final MultiTermsValuesSource multiTermsValue; + private final boolean showTermDocCountError; + private final List formats; + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + private final BucketOrder order; + private final Comparator partiallyBuiltBucketComparator; + private final SubAggCollectionMode collectMode; + private final Set aggsUsedForSorting = new HashSet<>(); + + public MultiTermsAggregator( + String name, + AggregatorFactories factories, + boolean showTermDocCountError, + List internalValuesSources, + List formats, + BucketOrder order, + SubAggCollectionMode collectMode, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + SearchContext context, + Aggregator parent, + CardinalityUpperBound cardinality, + Map metadata + ) throws IOException { + super(name, factories, context, parent, metadata); + this.bucketOrds = BytesKeyedBucketOrds.build(context.bigArrays(), cardinality); + this.multiTermsValue = new MultiTermsValuesSource(internalValuesSources); + this.showTermDocCountError = showTermDocCountError; + this.formats = formats; + this.bucketCountThresholds = bucketCountThresholds; + this.order = order; + this.partiallyBuiltBucketComparator = order == null ? null : order.partiallyBuiltBucketComparator(b -> b.bucketOrd, this); + // Todo, copy from TermsAggregator. need to remove duplicate code. + if (subAggsNeedScore() && descendsFromNestedAggregator(parent)) { + /** + * Force the execution to depth_first because we need to access the score of + * nested documents in a sub-aggregation and we are not able to generate this score + * while replaying deferred documents. 
+ */ + this.collectMode = SubAggCollectionMode.DEPTH_FIRST; + } else { + this.collectMode = collectMode; + } + // Don't defer any child agg if we are dependent on it for pruning results + if (order instanceof InternalOrder.Aggregation) { + AggregationPath path = ((InternalOrder.Aggregation) order).path(); + aggsUsedForSorting.add(path.resolveTopmostAggregator(this)); + } else if (order instanceof InternalOrder.CompoundOrder) { + InternalOrder.CompoundOrder compoundOrder = (InternalOrder.CompoundOrder) order; + for (BucketOrder orderElement : compoundOrder.orderElements()) { + if (orderElement instanceof InternalOrder.Aggregation) { + AggregationPath path = ((InternalOrder.Aggregation) orderElement).path(); + aggsUsedForSorting.add(path.resolveTopmostAggregator(this)); + } + } + } + } + + @Override + public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + InternalMultiTerms.Bucket[][] topBucketsPerOrd = new InternalMultiTerms.Bucket[owningBucketOrds.length][]; + long[] otherDocCounts = new long[owningBucketOrds.length]; + for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { + collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]); + long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]); + + int size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize()); + PriorityQueue ordered = new BucketPriorityQueue<>(size, partiallyBuiltBucketComparator); + InternalMultiTerms.Bucket spare = null; + BytesRef dest = null; + BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrds[ordIdx]); + CheckedSupplier emptyBucketBuilder = () -> InternalMultiTerms.Bucket.EMPTY( + showTermDocCountError, + formats + ); + while (ordsEnum.next()) { + long docCount = bucketDocCount(ordsEnum.ord()); + otherDocCounts[ordIdx] += docCount; + if (docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + if (spare == null) { + spare = emptyBucketBuilder.get(); + dest = new 
BytesRef(); + } + + ordsEnum.readValue(dest); + + spare.termValues = decode(dest); + spare.docCount = docCount; + spare.bucketOrd = ordsEnum.ord(); + spare = ordered.insertWithOverflow(spare); + } + + // Get the top buckets + InternalMultiTerms.Bucket[] bucketsForOrd = new InternalMultiTerms.Bucket[ordered.size()]; + topBucketsPerOrd[ordIdx] = bucketsForOrd; + for (int b = ordered.size() - 1; b >= 0; --b) { + topBucketsPerOrd[ordIdx][b] = ordered.pop(); + otherDocCounts[ordIdx] -= topBucketsPerOrd[ordIdx][b].getDocCount(); + } + } + + buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, aggs) -> b.aggregations = aggs); + + InternalAggregation[] result = new InternalAggregation[owningBucketOrds.length]; + for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { + result[ordIdx] = buildResult(owningBucketOrds[ordIdx], otherDocCounts[ordIdx], topBucketsPerOrd[ordIdx]); + } + return result; + } + + InternalMultiTerms buildResult(long owningBucketOrd, long otherDocCount, InternalMultiTerms.Bucket[] topBuckets) { + BucketOrder reduceOrder; + if (isKeyOrder(order) == false) { + reduceOrder = InternalOrder.key(true); + Arrays.sort(topBuckets, reduceOrder.comparator()); + } else { + reduceOrder = order; + } + return new InternalMultiTerms( + name, + reduceOrder, + order, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + metadata(), + bucketCountThresholds.getShardSize(), + showTermDocCountError, + otherDocCount, + 0, + formats, + org.opensearch.common.collect.List.of(topBuckets) + ); + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return null; + } + + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); + return new LeafBucketCollector() { + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + for (List 
value : collector.apply(doc)) { + long bucketOrd = bucketOrds.add(owningBucketOrd, encode(value)); + if (bucketOrd < 0) { + bucketOrd = -1 - bucketOrd; + collectExistingBucket(sub, doc, bucketOrd); + } else { + collectBucket(sub, doc, bucketOrd); + } + } + } + }; + } + + @Override + protected void doClose() { + Releasables.close(bucketOrds); + } + + private static BytesRef encode(List values) { + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.writeCollection(values, StreamOutput::writeGenericValue); + return output.bytes().toBytesRef(); + } catch (IOException e) { + throw ExceptionsHelper.convertToRuntime(e); + } + } + + private static List decode(BytesRef bytesRef) { + try (StreamInput input = new BytesArray(bytesRef).streamInput()) { + return input.readList(StreamInput::readGenericValue); + } catch (IOException e) { + throw ExceptionsHelper.convertToRuntime(e); + } + } + + private boolean subAggsNeedScore() { + for (Aggregator subAgg : subAggregators) { + if (subAgg.scoreMode().needsScores()) { + return true; + } + } + return false; + } + + @Override + protected boolean shouldDefer(Aggregator aggregator) { + return collectMode == Aggregator.SubAggCollectionMode.BREADTH_FIRST && !aggsUsedForSorting.contains(aggregator); + } + + private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOException { + if (bucketCountThresholds.getMinDocCount() != 0) { + return; + } + if (InternalOrder.isCountDesc(order) && bucketOrds.bucketsInOrd(owningBucketOrd) >= bucketCountThresholds.getRequiredSize()) { + return; + } + // we need to fill-in the blanks + for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) { + MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); + // brute force + for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) { + for (List value : collector.apply(docId)) { + bucketOrds.add(owningBucketOrd, encode(value)); + } + } + } + } + + /** + * A multi_terms collector which 
collect values on each doc, + */ + @FunctionalInterface + interface MultiTermsValuesSourceCollector { + /** + * Collect a list values of multi_terms on each doc. + * Each terms could have multi_values, so the result is the cartesian product of each term's values. + */ + List> apply(int doc) throws IOException; + } + + @FunctionalInterface + interface InternalValuesSource { + /** + * Create {@link InternalValuesSourceCollector} from existing {@link LeafReaderContext}. + */ + InternalValuesSourceCollector apply(LeafReaderContext ctx) throws IOException; + } + + /** + * A terms collector which collect values on each doc, + */ + @FunctionalInterface + interface InternalValuesSourceCollector { + /** + * Collect a list values of a term on specific doc. + */ + List apply(int doc) throws IOException; + } + + /** + * Multi_Term ValuesSource, it is a collection of {@link InternalValuesSource} + */ + static class MultiTermsValuesSource { + private final List valuesSources; + + public MultiTermsValuesSource(List valuesSources) { + this.valuesSources = valuesSources; + } + + public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException { + List collectors = new ArrayList<>(); + for (InternalValuesSource valuesSource : valuesSources) { + collectors.add(valuesSource.apply(ctx)); + } + return new MultiTermsValuesSourceCollector() { + @Override + public List> apply(int doc) throws IOException { + List, IOException>> collectedValues = new ArrayList<>(); + for (InternalValuesSourceCollector collector : collectors) { + collectedValues.add(() -> collector.apply(doc)); + } + List> result = new ArrayList<>(); + apply(0, collectedValues, new ArrayList<>(), result); + return result; + } + + /** + * DFS traverse each term's values and add cartesian product to results lists. 
+ */ + private void apply( + int index, + List, IOException>> collectedValues, + List current, + List> results + ) throws IOException { + if (index == collectedValues.size()) { + results.add(org.opensearch.common.collect.List.copyOf(current)); + } else if (null != collectedValues.get(index)) { + for (Object value : collectedValues.get(index).get()) { + current.add(value); + apply(index + 1, collectedValues, current, results); + current.remove(current.size() - 1); + } + } + } + }; + } + } + + /** + * Factory for construct {@link InternalValuesSource}. + */ + static class InternalValuesSourceFactory { + static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, IncludeExclude.StringFilter includeExclude) { + return ctx -> { + SortedBinaryDocValues values = valuesSource.bytesValues(ctx); + return doc -> { + BytesRefBuilder previous = new BytesRefBuilder(); + + if (false == values.advanceExact(doc)) { + return Collections.emptyList(); + } + int valuesCount = values.docValueCount(); + List termValues = new ArrayList<>(valuesCount); + + // SortedBinaryDocValues don't guarantee uniqueness so we + // need to take care of dups + previous.clear(); + for (int i = 0; i < valuesCount; ++i) { + BytesRef bytes = values.nextValue(); + if (includeExclude != null && false == includeExclude.accept(bytes)) { + continue; + } + if (i > 0 && previous.get().equals(bytes)) { + continue; + } + previous.copyBytes(bytes); + termValues.add(BytesRef.deepCopyOf(bytes)); + } + return termValues; + }; + }; + } + + static InternalValuesSource longValuesSource(ValuesSource.Numeric valuesSource, IncludeExclude.LongFilter longFilter) { + return ctx -> { + SortedNumericDocValues values = valuesSource.longValues(ctx); + return doc -> { + if (values.advanceExact(doc)) { + int valuesCount = values.docValueCount(); + + long previous = Long.MAX_VALUE; + List termValues = new ArrayList<>(valuesCount); + for (int i = 0; i < valuesCount; ++i) { + long val = values.nextValue(); + if (previous != 
val || i == 0) { + if (longFilter == null || longFilter.accept(val)) { + termValues.add(val); + } + previous = val; + } + } + return termValues; + } + return Collections.emptyList(); + }; + }; + } + + static InternalValuesSource doubleValueSource(ValuesSource.Numeric valuesSource, IncludeExclude.LongFilter longFilter) { + return ctx -> { + SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); + return doc -> { + if (values.advanceExact(doc)) { + int valuesCount = values.docValueCount(); + + double previous = Double.MAX_VALUE; + List termValues = new ArrayList<>(valuesCount); + for (int i = 0; i < valuesCount; ++i) { + double val = values.nextValue(); + if (previous != val || i == 0) { + if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val))) { + termValues.add(val); + } + previous = val; + } + } + return termValues; + } + return Collections.emptyList(); + }; + }; + } + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedMultiTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedMultiTerms.java new file mode 100644 index 00000000000..8686d329fa3 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedMultiTerms.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.opensearch.common.xcontent.ObjectParser; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; + +public class ParsedMultiTerms extends ParsedTerms { + @Override + public String getType() { + return MultiTermsAggregationBuilder.NAME; + } + + private static final ObjectParser PARSER = new ObjectParser<>( + ParsedMultiTerms.class.getSimpleName(), + true, + ParsedMultiTerms::new + ); + static { + declareParsedTermsFields(PARSER, ParsedBucket::fromXContent); + } + + public static ParsedMultiTerms fromXContent(XContentParser parser, String name) throws IOException { + ParsedMultiTerms aggregation = PARSER.parse(parser, null); + aggregation.setName(name); + return aggregation; + } + + public static class ParsedBucket extends ParsedTerms.ParsedBucket { + + private List key; + + @Override + public List getKey() { + return key; + } + + @Override + public String getKeyAsString() { + String keyAsString = super.getKeyAsString(); + if (keyAsString != null) { + return keyAsString; + } + if (key != null) { + return key.toString(); + } + return null; + } + + public Number getKeyAsNumber() { + throw new UnsupportedOperationException("not implemented"); + } + + @Override + protected XContentBuilder keyToXContent(XContentBuilder builder) throws IOException { + builder.field(CommonFields.KEY.getPreferredName(), key); + if (super.getKeyAsString() != null) { + builder.field(CommonFields.KEY_AS_STRING.getPreferredName(), getKeyAsString()); + } + return builder; + } + + static ParsedBucket fromXContent(XContentParser parser) throws IOException { + return parseTermsBucketXContent(parser, ParsedBucket::new, (p, bucket) -> { bucket.key = p.list(); }); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedTerms.java 
b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedTerms.java index ce5f56c898f..054ea7d8270 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/ParsedTerms.java @@ -139,13 +139,16 @@ public abstract class ParsedTerms extends ParsedMultiBucketAggregation>, Void>, + Boolean, + Boolean> PARSER = (parser, scriptable, timezoneAware) -> { + parser.declareString(Builder::setFieldName, ParseField.CommonFields.FIELD); + parser.declareField( + Builder::setMissing, + XContentParser::objectText, + ParseField.CommonFields.MISSING, + ObjectParser.ValueType.VALUE + ); + + if (scriptable) { + parser.declareField( + Builder::setScript, + (p, context) -> Script.parse(p), + Script.SCRIPT_PARSE_FIELD, + ObjectParser.ValueType.OBJECT_OR_STRING + ); + } + + if (timezoneAware) { + parser.declareField(Builder::setTimeZone, p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ZoneId.of(p.text()); + } else { + return ZoneOffset.ofHours(p.intValue()); + } + }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG); + } + }; + + public BaseMultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone) { + this.fieldName = fieldName; + this.missing = missing; + this.script = script; + this.timeZone = timeZone; + } + + public BaseMultiValuesSourceFieldConfig(StreamInput in) throws IOException { + if (in.getVersion().onOrAfter(LegacyESVersion.V_7_6_0)) { + this.fieldName = in.readOptionalString(); + } else { + this.fieldName = in.readString(); + } + this.missing = in.readGenericValue(); + this.script = in.readOptionalWriteable(Script::new); + if (in.getVersion().before(LegacyESVersion.V_7_0_0)) { + this.timeZone = DateUtils.dateTimeZoneToZoneId(in.readOptionalTimeZone()); + } else { + this.timeZone = in.readOptionalZoneId(); + } + } + + @Override + public void writeTo(StreamOutput 
out) throws IOException { + if (out.getVersion().onOrAfter(LegacyESVersion.V_7_6_0)) { + out.writeOptionalString(fieldName); + } else { + out.writeString(fieldName); + } + out.writeGenericValue(missing); + out.writeOptionalWriteable(script); + if (out.getVersion().before(LegacyESVersion.V_7_0_0)) { + out.writeOptionalTimeZone(DateUtils.zoneIdToDateTimeZone(timeZone)); + } else { + out.writeOptionalZoneId(timeZone); + } + doWriteTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (missing != null) { + builder.field(ParseField.CommonFields.MISSING.getPreferredName(), missing); + } + if (script != null) { + builder.field(Script.SCRIPT_PARSE_FIELD.getPreferredName(), script); + } + if (fieldName != null) { + builder.field(ParseField.CommonFields.FIELD.getPreferredName(), fieldName); + } + if (timeZone != null) { + builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone.getId()); + } + doXContentBody(builder, params); + builder.endObject(); + return builder; + } + + public Object getMissing() { + return missing; + } + + public Script getScript() { + return script; + } + + public ZoneId getTimeZone() { + return timeZone; + } + + public String getFieldName() { + return fieldName; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BaseMultiValuesSourceFieldConfig that = (BaseMultiValuesSourceFieldConfig) o; + return Objects.equals(fieldName, that.fieldName) + && Objects.equals(missing, that.missing) + && Objects.equals(script, that.script) + && Objects.equals(timeZone, that.timeZone); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, missing, script, timeZone); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + abstract void doXContentBody(XContentBuilder builder, Params params) throws 
IOException; + + abstract void doWriteTo(StreamOutput out) throws IOException; + + public abstract static class Builder> { + String fieldName; + Object missing = null; + Script script = null; + ZoneId timeZone = null; + + public String getFieldName() { + return fieldName; + } + + public B setFieldName(String fieldName) { + this.fieldName = fieldName; + return (B) this; + } + + public Object getMissing() { + return missing; + } + + public B setMissing(Object missing) { + this.missing = missing; + return (B) this; + } + + public Script getScript() { + return script; + } + + public B setScript(Script script) { + this.script = script; + return (B) this; + } + + public ZoneId getTimeZone() { + return timeZone; + } + + public B setTimeZone(ZoneId timeZone) { + this.timeZone = timeZone; + return (B) this; + } + + abstract public C build(); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfig.java b/server/src/main/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfig.java new file mode 100644 index 00000000000..3bc7f444c61 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfig.java @@ -0,0 +1,203 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.support; + +import org.opensearch.common.ParseField; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.ObjectParser; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.IncludeExclude; + +import java.io.IOException; +import java.time.ZoneId; +import java.util.Objects; + +/** + * A configuration that used by multi_terms aggregations. + */ +public class MultiTermsValuesSourceConfig extends BaseMultiValuesSourceFieldConfig { + private final ValueType userValueTypeHint; + private final String format; + private final IncludeExclude includeExclude; + + private static final String NAME = "field_config"; + public static final ParseField FILTER = new ParseField("filter"); + + public interface ParserSupplier { + ObjectParser apply( + Boolean scriptable, + Boolean timezoneAware, + Boolean valueTypeHinted, + Boolean formatted + ); + } + + public static final MultiTermsValuesSourceConfig.ParserSupplier PARSER = (scriptable, timezoneAware, valueTypeHinted, formatted) -> { + + ObjectParser parser = new ObjectParser<>( + MultiTermsValuesSourceConfig.NAME, + MultiTermsValuesSourceConfig.Builder::new + ); + + BaseMultiValuesSourceFieldConfig.PARSER.apply(parser, scriptable, timezoneAware); + + if (valueTypeHinted) { + parser.declareField( + MultiTermsValuesSourceConfig.Builder::setUserValueTypeHint, + p -> ValueType.lenientParse(p.text()), + ValueType.VALUE_TYPE, + ObjectParser.ValueType.STRING + ); + } + + if (formatted) { + parser.declareField( + MultiTermsValuesSourceConfig.Builder::setFormat, + XContentParser::text, + ParseField.CommonFields.FORMAT, + ObjectParser.ValueType.STRING + 
); + } + + parser.declareField( + (b, v) -> b.setIncludeExclude(IncludeExclude.merge(b.getIncludeExclude(), v)), + IncludeExclude::parseExclude, + IncludeExclude.EXCLUDE_FIELD, + ObjectParser.ValueType.STRING_ARRAY + ); + + return parser; + }; + + protected MultiTermsValuesSourceConfig( + String fieldName, + Object missing, + Script script, + ZoneId timeZone, + ValueType userValueTypeHint, + String format, + IncludeExclude includeExclude + ) { + super(fieldName, missing, script, timeZone); + this.userValueTypeHint = userValueTypeHint; + this.format = format; + this.includeExclude = includeExclude; + } + + public MultiTermsValuesSourceConfig(StreamInput in) throws IOException { + super(in); + this.userValueTypeHint = in.readOptionalWriteable(ValueType::readFromStream); + this.format = in.readOptionalString(); + this.includeExclude = in.readOptionalWriteable(IncludeExclude::new); + } + + public ValueType getUserValueTypeHint() { + return userValueTypeHint; + } + + public String getFormat() { + return format; + } + + /** + * Get terms to include and exclude from the aggregation results + */ + public IncludeExclude getIncludeExclude() { + return includeExclude; + } + + @Override + public void doWriteTo(StreamOutput out) throws IOException { + out.writeOptionalWriteable(userValueTypeHint); + out.writeOptionalString(format); + out.writeOptionalWriteable(includeExclude); + } + + @Override + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + if (userValueTypeHint != null) { + builder.field(AggregationBuilder.CommonFields.VALUE_TYPE.getPreferredName(), userValueTypeHint.getPreferredName()); + } + if (format != null) { + builder.field(AggregationBuilder.CommonFields.FORMAT.getPreferredName(), format); + } + if (includeExclude != null) { + includeExclude.toXContent(builder, params); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if 
(super.equals(o) == false) return false; + + MultiTermsValuesSourceConfig that = (MultiTermsValuesSourceConfig) o; + return Objects.equals(userValueTypeHint, that.userValueTypeHint) + && Objects.equals(format, that.format) + && Objects.equals(includeExclude, that.includeExclude); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), userValueTypeHint, format, includeExclude); + } + + public static class Builder extends BaseMultiValuesSourceFieldConfig.Builder { + private ValueType userValueTypeHint = null; + private String format; + private IncludeExclude includeExclude = null; + + public IncludeExclude getIncludeExclude() { + return includeExclude; + } + + public Builder setIncludeExclude(IncludeExclude includeExclude) { + this.includeExclude = includeExclude; + return this; + } + + public ValueType getUserValueTypeHint() { + return userValueTypeHint; + } + + public Builder setUserValueTypeHint(ValueType userValueTypeHint) { + this.userValueTypeHint = userValueTypeHint; + return this; + } + + public String getFormat() { + return format; + } + + public Builder setFormat(String format) { + this.format = format; + return this; + } + + public MultiTermsValuesSourceConfig build() { + if (Strings.isNullOrEmpty(fieldName) && script == null) { + throw new IllegalArgumentException( + "[" + + ParseField.CommonFields.FIELD.getPreferredName() + + "] and [" + + Script.SCRIPT_PARSE_FIELD.getPreferredName() + + "] cannot both be null. " + + "Please specify one or the other." 
+ ); + } + return new MultiTermsValuesSourceConfig(fieldName, missing, script, timeZone, userValueTypeHint, format, includeExclude); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/opensearch/search/aggregations/support/MultiValuesSourceFieldConfig.java index 54450763148..ea9bbe80192 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/MultiValuesSourceFieldConfig.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/MultiValuesSourceFieldConfig.java @@ -38,26 +38,17 @@ import org.opensearch.common.Strings; import org.opensearch.common.TriFunction; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; -import org.opensearch.common.time.DateUtils; import org.opensearch.common.xcontent.ObjectParser; -import org.opensearch.common.xcontent.ToXContentObject; import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentParser; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.script.Script; import java.io.IOException; import java.time.ZoneId; -import java.time.ZoneOffset; import java.util.Objects; -public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject { - private final String fieldName; - private final Object missing; - private final Script script; - private final ZoneId timeZone; +public class MultiValuesSourceFieldConfig extends BaseMultiValuesSourceFieldConfig { private final QueryBuilder filter; private static final String NAME = "field_config"; @@ -73,32 +64,7 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject MultiValuesSourceFieldConfig.Builder::new ); - parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, 
ParseField.CommonFields.FIELD); - parser.declareField( - MultiValuesSourceFieldConfig.Builder::setMissing, - XContentParser::objectText, - ParseField.CommonFields.MISSING, - ObjectParser.ValueType.VALUE - ); - - if (scriptable) { - parser.declareField( - MultiValuesSourceFieldConfig.Builder::setScript, - (p, context) -> Script.parse(p), - Script.SCRIPT_PARSE_FIELD, - ObjectParser.ValueType.OBJECT_OR_STRING - ); - } - - if (timezoneAware) { - parser.declareField(MultiValuesSourceFieldConfig.Builder::setTimeZone, p -> { - if (p.currentToken() == XContentParser.Token.VALUE_STRING) { - return ZoneId.of(p.text()); - } else { - return ZoneOffset.ofHours(p.intValue()); - } - }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG); - } + BaseMultiValuesSourceFieldConfig.PARSER.apply(parser, scriptable, timezoneAware); if (filtered) { parser.declareField( @@ -112,26 +78,12 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject }; protected MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone, QueryBuilder filter) { - this.fieldName = fieldName; - this.missing = missing; - this.script = script; - this.timeZone = timeZone; + super(fieldName, missing, script, timeZone); this.filter = filter; } public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { - if (in.getVersion().onOrAfter(LegacyESVersion.V_7_6_0)) { - this.fieldName = in.readOptionalString(); - } else { - this.fieldName = in.readString(); - } - this.missing = in.readGenericValue(); - this.script = in.readOptionalWriteable(Script::new); - if (in.getVersion().before(LegacyESVersion.V_7_0_0)) { - this.timeZone = DateUtils.dateTimeZoneToZoneId(in.readOptionalTimeZone()); - } else { - this.timeZone = in.readOptionalZoneId(); - } + super(in); if (in.getVersion().onOrAfter(LegacyESVersion.V_7_8_0)) { this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); } else { @@ -139,133 +91,43 @@ public class 
MultiValuesSourceFieldConfig implements Writeable, ToXContentObject } } - public Object getMissing() { - return missing; - } - - public Script getScript() { - return script; - } - - public ZoneId getTimeZone() { - return timeZone; - } - - public String getFieldName() { - return fieldName; - } - public QueryBuilder getFilter() { return filter; } @Override - public void writeTo(StreamOutput out) throws IOException { - if (out.getVersion().onOrAfter(LegacyESVersion.V_7_6_0)) { - out.writeOptionalString(fieldName); - } else { - out.writeString(fieldName); - } - out.writeGenericValue(missing); - out.writeOptionalWriteable(script); - if (out.getVersion().before(LegacyESVersion.V_7_0_0)) { - out.writeOptionalTimeZone(DateUtils.zoneIdToDateTimeZone(timeZone)); - } else { - out.writeOptionalZoneId(timeZone); - } + public void doWriteTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(LegacyESVersion.V_7_8_0)) { out.writeOptionalNamedWriteable(filter); } } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (missing != null) { - builder.field(ParseField.CommonFields.MISSING.getPreferredName(), missing); - } - if (script != null) { - builder.field(Script.SCRIPT_PARSE_FIELD.getPreferredName(), script); - } - if (fieldName != null) { - builder.field(ParseField.CommonFields.FIELD.getPreferredName(), fieldName); - } - if (timeZone != null) { - builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone.getId()); - } + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { if (filter != null) { builder.field(FILTER.getPreferredName()); filter.toXContent(builder, params); } - builder.endObject(); - return builder; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + MultiValuesSourceFieldConfig that = 
(MultiValuesSourceFieldConfig) o; - return Objects.equals(fieldName, that.fieldName) - && Objects.equals(missing, that.missing) - && Objects.equals(script, that.script) - && Objects.equals(timeZone, that.timeZone) - && Objects.equals(filter, that.filter); + return Objects.equals(filter, that.filter); } @Override public int hashCode() { - return Objects.hash(fieldName, missing, script, timeZone, filter); + return Objects.hash(super.hashCode(), filter); } - @Override - public String toString() { - return Strings.toString(this); - } - - public static class Builder { - private String fieldName; - private Object missing = null; - private Script script = null; - private ZoneId timeZone = null; + public static class Builder extends BaseMultiValuesSourceFieldConfig.Builder { private QueryBuilder filter = null; - public String getFieldName() { - return fieldName; - } - - public Builder setFieldName(String fieldName) { - this.fieldName = fieldName; - return this; - } - - public Object getMissing() { - return missing; - } - - public Builder setMissing(Object missing) { - this.missing = missing; - return this; - } - - public Script getScript() { - return script; - } - - public Builder setScript(Script script) { - this.script = script; - return this; - } - - public ZoneId getTimeZone() { - return timeZone; - } - - public Builder setTimeZone(ZoneId timeZone) { - this.timeZone = timeZone; - return this; - } - public Builder setFilter(QueryBuilder filter) { this.filter = filter; return this; diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationsTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationsTests.java index fe029d22a45..421865013a2 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/AggregationsTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationsTests.java @@ -64,6 +64,7 @@ import org.opensearch.search.aggregations.bucket.range.InternalGeoDistanceTests; import 
org.opensearch.search.aggregations.bucket.range.InternalRangeTests; import org.opensearch.search.aggregations.bucket.sampler.InternalSamplerTests; import org.opensearch.search.aggregations.bucket.terms.DoubleTermsTests; +import org.opensearch.search.aggregations.bucket.terms.InternalMultiTermsTests; import org.opensearch.search.aggregations.bucket.terms.LongRareTermsTests; import org.opensearch.search.aggregations.bucket.terms.LongTermsTests; import org.opensearch.search.aggregations.bucket.terms.SignificantLongTermsTests; @@ -172,6 +173,7 @@ public class AggregationsTests extends OpenSearchTestCase { aggsTests.add(new InternalTopHitsTests()); aggsTests.add(new InternalCompositeTests()); aggsTests.add(new InternalMedianAbsoluteDeviationTests()); + aggsTests.add(new InternalMultiTermsTests()); return Collections.unmodifiableList(aggsTests); } diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTermsTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTermsTests.java new file mode 100644 index 00000000000..2657f2bdd51 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTermsTests.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.apache.lucene.document.InetAddressPoint; +import org.apache.lucene.util.BytesRef; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.ParsedMultiBucketAggregation; +import org.opensearch.search.aggregations.support.CoreValuesSourceType; +import org.opensearch.search.aggregations.support.ValuesSourceType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; + +public class InternalMultiTermsTests extends InternalTermsTestCase { + + /** + * terms count and type should consistent across entire test. + */ + private final List types = getSupportedValuesSourceTypes(); + + @Override + protected InternalTerms createTestInstance( + String name, + Map metadata, + InternalAggregations aggregations, + boolean showTermDocCountError, + long docCountError + ) { + BucketOrder order = BucketOrder.count(false); + long minDocCount = 1; + int requiredSize = 3; + int shardSize = requiredSize + 2; + long otherDocCount = 0; + + final int numBuckets = randomNumberOfBuckets(); + + List buckets = new ArrayList<>(); + List formats = types.stream().map(type -> type.getFormatter(null, null)).collect(Collectors.toList()); + + for (int i = 0; i < numBuckets; i++) { + buckets.add( + new InternalMultiTerms.Bucket( + types.stream().map(this::value).collect(Collectors.toList()), + minDocCount, + aggregations, + showTermDocCountError, + docCountError, + formats + ) + ); + } + BucketOrder reduceOrder = rarely() ? order : BucketOrder.key(true); + // mimic per-shard bucket sort operation, which is required by bucket reduce phase. 
+ Collections.sort(buckets, reduceOrder.comparator()); + return new InternalMultiTerms( + name, + reduceOrder, + order, + requiredSize, + minDocCount, + metadata, + shardSize, + showTermDocCountError, + otherDocCount, + docCountError, + formats, + buckets + ); + } + + @Override + protected Class implementationClass() { + return ParsedMultiTerms.class; + } + + private static List getSupportedValuesSourceTypes() { + return Collections.unmodifiableList( + asList( + CoreValuesSourceType.NUMERIC, + CoreValuesSourceType.BYTES, + CoreValuesSourceType.IP, + CoreValuesSourceType.DATE, + CoreValuesSourceType.BOOLEAN + ) + ); + } + + private Object value(ValuesSourceType type) { + if (CoreValuesSourceType.NUMERIC.equals(type)) { + return randomInt(); + } else if (CoreValuesSourceType.DATE.equals(type)) { + return randomNonNegativeLong(); + } else if (CoreValuesSourceType.BOOLEAN.equals(type)) { + return randomBoolean(); + } else if (CoreValuesSourceType.BYTES.equals(type)) { + return new BytesRef(randomAlphaOfLength(10)); + } else if (CoreValuesSourceType.IP.equals(type)) { + return new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); + } + throw new IllegalArgumentException("unexpected type [" + type.typeName() + "]"); + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilderTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilderTests.java new file mode 100644 index 00000000000..505fb7382ab --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationBuilderTests.java @@ -0,0 +1,182 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.BaseAggregationTestCase; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValueType; + +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; + +public class MultiTermsAggregationBuilderTests extends BaseAggregationTestCase { + + @Override + protected MultiTermsAggregationBuilder createTestAggregatorBuilder() { + String name = randomAlphaOfLengthBetween(3, 20); + MultiTermsAggregationBuilder factory = new MultiTermsAggregationBuilder(name); + + int termsCount = randomIntBetween(2, 10); + List fieldConfigs = new ArrayList<>(); + for (int i = 0; i < termsCount; i++) { + fieldConfigs.add(randomFieldConfig()); + } + factory.terms(fieldConfigs); + + if (randomBoolean()) { + factory.size(randomIntBetween(1, Integer.MAX_VALUE)); + } + if (randomBoolean()) { + factory.shardSize(randomIntBetween(1, Integer.MAX_VALUE)); + } + if (randomBoolean()) { + int minDocCount = randomInt(4); + switch (minDocCount) { + case 0: + break; + case 1: + case 2: + case 3: + case 4: + minDocCount = randomIntBetween(0, Integer.MAX_VALUE); + break; + default: + fail(); + } + factory.minDocCount(minDocCount); + } + if (randomBoolean()) { + int shardMinDocCount = randomInt(4); + switch (shardMinDocCount) { + case 0: + break; + case 1: + case 2: + case 3: + case 4: + shardMinDocCount = randomIntBetween(0, Integer.MAX_VALUE); + break; + default: + fail(); + } + factory.shardMinDocCount(shardMinDocCount); + } + if (randomBoolean()) { + factory.collectMode(randomFrom(Aggregator.SubAggCollectionMode.values())); + } + if (randomBoolean()) { + List order = randomOrder(); 
+ if (order.size() == 1 && randomBoolean()) { + factory.order(order.get(0)); + } else { + factory.order(order); + } + } + if (randomBoolean()) { + factory.showTermDocCountError(randomBoolean()); + } + return factory; + } + + public void testInvalidTermsParams() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> { new MultiTermsAggregationBuilder("_name").terms(Collections.singletonList(randomFieldConfig())); } + ); + assertEquals( + "multi term aggregation must has at least 2 terms. Found [1] in [_name] Use terms aggregation for single term aggregation", + exception.getMessage() + ); + + exception = expectThrows( + IllegalArgumentException.class, + () -> { new MultiTermsAggregationBuilder("_name").terms(Collections.emptyList()); } + ); + assertEquals("multi term aggregation must has at least 2 terms. Found [0] in [_name]", exception.getMessage()); + + exception = expectThrows(IllegalArgumentException.class, () -> { new MultiTermsAggregationBuilder("_name").terms(null); }); + assertEquals("[terms] must not be null. Found null terms in [_name]", exception.getMessage()); + } + + private List randomOrder() { + List orders = new ArrayList<>(); + switch (randomInt(4)) { + case 0: + orders.add(BucketOrder.key(randomBoolean())); + break; + case 1: + orders.add(BucketOrder.count(randomBoolean())); + break; + case 2: + orders.add(BucketOrder.aggregation(randomAlphaOfLengthBetween(3, 20), randomBoolean())); + break; + case 3: + orders.add(BucketOrder.aggregation(randomAlphaOfLengthBetween(3, 20), randomAlphaOfLengthBetween(3, 20), randomBoolean())); + break; + case 4: + int numOrders = randomIntBetween(1, 3); + for (int i = 0; i < numOrders; i++) { + orders.addAll(randomOrder()); + } + break; + default: + fail(); + } + return orders; + } + + protected static MultiTermsValuesSourceConfig randomFieldConfig() { + String field = randomAlphaOfLength(10); + Object missing = randomBoolean() ? 
randomAlphaOfLength(10) : null; + ZoneId timeZone = randomBoolean() ? randomZone() : null; + ValueType userValueTypeHint = randomBoolean() + ? randomFrom(ValueType.STRING, ValueType.LONG, ValueType.DOUBLE, ValueType.DATE, ValueType.IP) + : null; + String format = randomBoolean() ? randomNumericDocValueFormat().toString() : null; + return randomFieldOrScript( + new MultiTermsValuesSourceConfig.Builder().setMissing(missing) + .setTimeZone(timeZone) + .setUserValueTypeHint(userValueTypeHint) + .setFormat(format), + field + ).build(); + } + + protected static MultiTermsValuesSourceConfig.Builder randomFieldOrScript(MultiTermsValuesSourceConfig.Builder builder, String field) { + int choice = randomInt(1); + switch (choice) { + case 0: + builder.setFieldName(field); + break; + case 1: + builder.setScript(mockScript("doc[" + field + "] + 1")); + break; + default: + throw new AssertionError("Unknown random operation [" + choice + "]"); + } + return builder; + } + + /** + * @return a random {@link DocValueFormat} that can be used in aggregations which + * compute numbers. + */ + protected static DocValueFormat randomNumericDocValueFormat() { + final List> formats = new ArrayList<>(3); + formats.add(() -> DocValueFormat.RAW); + formats.add(() -> new DocValueFormat.Decimal(randomFrom("###.##", "###,###.##"))); + return randomFrom(formats).get(); + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java new file mode 100644 index 00000000000..f3922a65ff2 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java @@ -0,0 +1,909 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.terms; + +import org.apache.lucene.document.DoubleDocValuesField; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.document.InetAddressPoint; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.util.BytesRef; +import org.hamcrest.MatcherAssert; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.network.InetAddresses; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.time.DateFormatter; +import org.opensearch.index.mapper.BooleanFieldMapper; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.GeoPointFieldMapper; +import org.opensearch.index.mapper.IpFieldMapper; +import org.opensearch.index.mapper.KeywordFieldMapper; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.script.MockScriptEngine; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptEngine; +import org.opensearch.script.ScriptModule; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregatorTestCase; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.metrics.InternalMax; +import 
org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.opensearch.search.aggregations.support.CoreValuesSourceType; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValueType; +import org.opensearch.search.aggregations.support.ValuesSourceType; +import org.opensearch.search.lookup.LeafDocLookup; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class MultiTermsAggregatorTests extends AggregatorTestCase { + private static final String FIELD_NAME = "field"; + private static final String VALUE_SCRIPT_NAME = "value_script"; + private static final String FIELD_SCRIPT_NAME = "field_script"; + + private static final String AGG_NAME = "_name"; + + private static final String INT_FIELD = "int"; + private static final String LONG_FIELD = "long"; + private static final String FLOAT_FIELD = "float"; + private static final String DOUBLE_FIELD = "double"; + private static final String KEYWORD_FIELD = "keyword"; + private static final String DATE_FIELD = "date"; + private static final String IP_FIELD = "ip"; + private static final String GEO_POINT_FIELD = "geopoint"; + private static final String BOOL_FIELD = "bool"; + private static final String UNRELATED_KEYWORD_FIELD = "unrelated"; + + private static final Map mappedFieldTypeMap = new HashMap() { + { + 
put(INT_FIELD, new NumberFieldMapper.NumberFieldType(INT_FIELD, NumberFieldMapper.NumberType.INTEGER)); + put(LONG_FIELD, new NumberFieldMapper.NumberFieldType(LONG_FIELD, NumberFieldMapper.NumberType.LONG)); + put(FLOAT_FIELD, new NumberFieldMapper.NumberFieldType(FLOAT_FIELD, NumberFieldMapper.NumberType.FLOAT)); + put(DOUBLE_FIELD, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD, NumberFieldMapper.NumberType.DOUBLE)); + put(DATE_FIELD, dateFieldType(DATE_FIELD)); + put(KEYWORD_FIELD, new KeywordFieldMapper.KeywordFieldType(KEYWORD_FIELD)); + put(IP_FIELD, new IpFieldMapper.IpFieldType(IP_FIELD)); + put(FIELD_NAME, new NumberFieldMapper.NumberFieldType(FIELD_NAME, NumberFieldMapper.NumberType.INTEGER)); + put(UNRELATED_KEYWORD_FIELD, new KeywordFieldMapper.KeywordFieldType(UNRELATED_KEYWORD_FIELD)); + put(GEO_POINT_FIELD, new GeoPointFieldMapper.GeoPointFieldType(GEO_POINT_FIELD)); + put(BOOL_FIELD, new BooleanFieldMapper.BooleanFieldType(BOOL_FIELD)); + } + }; + + private static final Consumer NONE_DECORATOR = null; + + @Override + protected List getSupportedValuesSourceTypes() { + return Collections.unmodifiableList( + asList( + CoreValuesSourceType.NUMERIC, + CoreValuesSourceType.BYTES, + CoreValuesSourceType.IP, + CoreValuesSourceType.DATE, + CoreValuesSourceType.BOOLEAN + ) + ); + } + + @Override + protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) { + return createTestAggregatorBuilder(asList(term(fieldName), term(fieldName))); + } + + @Override + protected ScriptService getMockScriptService() { + final Map, Object>> scripts = org.opensearch.common.collect.Map.of( + VALUE_SCRIPT_NAME, + vars -> ((Number) vars.get("_value")).doubleValue() + 1, + FIELD_SCRIPT_NAME, + vars -> { + final String fieldName = (String) vars.get(FIELD_NAME); + final LeafDocLookup lookup = (LeafDocLookup) vars.get("doc"); + return lookup.get(fieldName).stream().map(value -> ((Number) value).longValue() + 1).collect(toList()); + } + 
); + final MockScriptEngine engine = new MockScriptEngine(MockScriptEngine.NAME, scripts, emptyMap()); + final Map engines = singletonMap(engine.getType(), engine); + return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); + } + + public void testNumbers() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + fieldConfigs(asList(INT_FIELD, LONG_FIELD, FLOAT_FIELD, DOUBLE_FIELD)), + NONE_DECORATOR, + iw -> { + iw.addDocument( + asList( + new NumericDocValuesField(INT_FIELD, 1), + new SortedNumericDocValuesField(LONG_FIELD, 1L), + new FloatDocValuesField(FLOAT_FIELD, 1.0f), + new DoubleDocValuesField(DOUBLE_FIELD, 1.0d) + ) + ); + iw.addDocument( + asList( + new NumericDocValuesField(INT_FIELD, 1), + new SortedNumericDocValuesField(LONG_FIELD, 1L), + new FloatDocValuesField(FLOAT_FIELD, 1.0f), + new DoubleDocValuesField(DOUBLE_FIELD, 1.0d) + ) + ); + iw.addDocument( + asList( + new NumericDocValuesField(INT_FIELD, 2), + new SortedNumericDocValuesField(LONG_FIELD, 2L), + new FloatDocValuesField(FLOAT_FIELD, 2.0f), + new DoubleDocValuesField(DOUBLE_FIELD, 2.0d) + ) + ); + iw.addDocument( + asList( + new NumericDocValuesField(INT_FIELD, 2), + new SortedNumericDocValuesField(LONG_FIELD, 2L), + new FloatDocValuesField(FLOAT_FIELD, 3.0f), + new DoubleDocValuesField(DOUBLE_FIELD, 3.0d) + ) + ); + iw.addDocument( + asList( + new NumericDocValuesField(INT_FIELD, 2), + new SortedNumericDocValuesField(LONG_FIELD, 2L), + new FloatDocValuesField(FLOAT_FIELD, 3.0f), + new DoubleDocValuesField(DOUBLE_FIELD, 3.0d) + ) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo(1L), equalTo(1L), equalTo(1.0), equalTo(1.0))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo(2L), equalTo(2L), equalTo(3.0), equalTo(3.0))); + 
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo(2L), equalTo(2L), equalTo(2.0), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + } + ); + } + + public void testMixNumberAndKeywordWithFilter() throws IOException { + testAggregation( + new TermQuery(new Term(KEYWORD_FIELD, "a")), + fieldConfigs(asList(KEYWORD_FIELD, FLOAT_FIELD)), + NONE_DECORATOR, + iw -> { + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new StringField(KEYWORD_FIELD, "a", Field.Store.NO), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new StringField(KEYWORD_FIELD, "a", Field.Store.NO), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new StringField(KEYWORD_FIELD, "b", Field.Store.NO), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new StringField(KEYWORD_FIELD, "a", Field.Store.NO), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(2)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(1.0))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + } + ); + } + + public void testMixNumberAndKeyword() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD, FLOAT_FIELD)), NONE_DECORATOR, iw -> { + iw.addDocument( + asList( + new 
SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), + new NumericDocValuesField(INT_FIELD, 2), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L), equalTo(1.0))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo(1L), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo(2L), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + }); + } + + public void testMultiValuesField() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), NONE_DECORATOR, iw -> { + iw.addDocument( + asList( + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + new SortedNumericDocValuesField(INT_FIELD, 1) + ) + ); + iw.addDocument( + asList( + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new SortedNumericDocValuesField(INT_FIELD, 1), + new SortedNumericDocValuesField(INT_FIELD, 3) + ) + ); + }, h -> { + 
MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(3L))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("b"), equalTo(1L))); + }); + + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), NONE_DECORATOR, iw -> { + iw.addDocument( + asList( + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + new SortedNumericDocValuesField(INT_FIELD, 1), + new SortedNumericDocValuesField(INT_FIELD, 2) + ) + ); + iw.addDocument( + asList( + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("c")), + new SortedNumericDocValuesField(INT_FIELD, 1), + new SortedNumericDocValuesField(INT_FIELD, 3) + ) + ); + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(7)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(2L))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("a"), equalTo(3L))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(3).getKey(), contains(equalTo("b"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(3).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(4).getKey(), 
contains(equalTo("b"), equalTo(2L))); + MatcherAssert.assertThat(h.getBuckets().get(4).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(5).getKey(), contains(equalTo("c"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(5).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(6).getKey(), contains(equalTo("c"), equalTo(3L))); + MatcherAssert.assertThat(h.getBuckets().get(6).getDocCount(), equalTo(1L)); + }); + } + + public void testScripts() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(KEYWORD_FIELD).build(), + new MultiTermsValuesSourceConfig.Builder().setScript( + new Script(ScriptType.INLINE, MockScriptEngine.NAME, FIELD_SCRIPT_NAME, singletonMap(FIELD_NAME, FIELD_NAME)) + ).setUserValueTypeHint(ValueType.LONG).build() + ), + null, + iw -> { + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(FIELD_NAME, 1)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(FIELD_NAME, 2)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(FIELD_NAME, 2)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new NumericDocValuesField(FIELD_NAME, 3)) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("b"), equalTo(3L))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(2L))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo(4L))); + 
MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + } + ); + } + + public void testScriptsWithoutValueTypeHint() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(KEYWORD_FIELD).build(), + new MultiTermsValuesSourceConfig.Builder().setScript( + new Script(ScriptType.INLINE, MockScriptEngine.NAME, FIELD_SCRIPT_NAME, singletonMap(FIELD_NAME, FIELD_NAME)) + ).build() + ), + null, + iw -> { + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(FIELD_NAME, 1)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(FIELD_NAME, 2)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(FIELD_NAME, 2)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new NumericDocValuesField(FIELD_NAME, 3)) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("b"), equalTo("3"))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo("2"))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo("4"))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + } + ); + } + + public void testValueScripts() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(KEYWORD_FIELD).build(), + new MultiTermsValuesSourceConfig.Builder().setFieldName(FIELD_NAME) + .setScript(new Script(ScriptType.INLINE, MockScriptEngine.NAME, VALUE_SCRIPT_NAME, 
emptyMap())) + .build() + ), + null, + iw -> { + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(FIELD_NAME, 1)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(FIELD_NAME, 2)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(FIELD_NAME, 2)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new NumericDocValuesField(FIELD_NAME, 3)) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("b"), equalTo(3.0))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo(4.0))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + } + ); + } + + public void testOrderByMetrics() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), b -> { + b.order(BucketOrder.aggregation("max", false)); + b.subAggregation(new MaxAggregationBuilder("max").field(FLOAT_FIELD)); + }, iw -> { + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + new NumericDocValuesField(INT_FIELD, 2), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), + new NumericDocValuesField(INT_FIELD, 3), + new 
FloatDocValuesField(FLOAT_FIELD, 3.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 4.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + new NumericDocValuesField(INT_FIELD, 2), + new FloatDocValuesField(FLOAT_FIELD, 3.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), + new NumericDocValuesField(INT_FIELD, 3), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(((InternalMax) (h.getBuckets().get(0).getAggregations().get("max"))).value(), closeTo(4.0f, 0.01)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo(2L))); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(((InternalMax) (h.getBuckets().get(1).getAggregations().get("max"))).value(), closeTo(3.0f, 0.01)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo(3L))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(((InternalMax) (h.getBuckets().get(2).getAggregations().get("max"))).value(), closeTo(3.0f, 0.01)); + }); + } + + public void testNumberFieldFormat() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList(term(KEYWORD_FIELD), new MultiTermsValuesSourceConfig.Builder().setFieldName(DOUBLE_FIELD).setFormat("00.00").build()), + null, + iw -> { + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new DoubleDocValuesField(DOUBLE_FIELD, 1.0d)) + ); + 
iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new DoubleDocValuesField(DOUBLE_FIELD, 2.0d)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new DoubleDocValuesField(DOUBLE_FIELD, 2.0d)) + ); + iw.addDocument( + asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new DoubleDocValuesField(DOUBLE_FIELD, 1.0d)) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|01.00")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("a|02.00")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("b|02.00")); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + } + ); + } + + public void testDates() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList(new MultiTermsValuesSourceConfig.Builder().setFieldName(DATE_FIELD).build(), term(KEYWORD_FIELD)), + null, + iw -> { + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-23")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")) + ) + ); + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-23")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")) + ) + ); + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-22")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")) + ) + ); + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-23")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")) + ) + ); + iw.addDocument( 
+ asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-21")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")) + ) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(4)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("2022-03-23|a")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("2022-03-21|c")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("2022-03-22|a")); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(3).getKeyAsString(), equalTo("2022-03-23|b")); + MatcherAssert.assertThat(h.getBuckets().get(3).getDocCount(), equalTo(1L)); + } + ); + } + + public void testDatesFieldFormat() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(DATE_FIELD).setFormat("yyyy/MM/dd").build(), + term(KEYWORD_FIELD) + ), + null, + iw -> { + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-23")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")) + ) + ); + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-23")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")) + ) + ); + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-22")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")) + ) + ); + iw.addDocument( + asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-23")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")) + ) + ); + iw.addDocument( + 
asList( + new SortedNumericDocValuesField(DATE_FIELD, dateFieldType(DATE_FIELD).parse("2022-03-21")), + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")) + ) + ); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(4)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("2022/03/23|a")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("2022/03/21|c")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("2022/03/22|a")); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(3).getKeyAsString(), equalTo("2022/03/23|b")); + MatcherAssert.assertThat(h.getBuckets().get(3).getDocCount(), equalTo(1L)); + } + ); + } + + public void testIpAndKeyword() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, IP_FIELD)), NONE_DECORATOR, iw -> { + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new SortedDocValuesField(IP_FIELD, new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.0")))) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + new SortedDocValuesField(IP_FIELD, new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.1")))) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), + new SortedDocValuesField(IP_FIELD, new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.2")))) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new SortedDocValuesField(IP_FIELD, new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.0")))) + ) + ); + }, h -> { + 
MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("192.168.0.0"))); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|192.168.0.0")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo("192.168.0.1"))); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("b|192.168.0.1")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo("192.168.0.2"))); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("c|192.168.0.2")); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + }); + } + + public void testEmpty() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), NONE_DECORATOR, iw -> {}, h -> { + MatcherAssert.assertThat(h.getName(), equalTo(AGG_NAME)); + MatcherAssert.assertThat(h.getBuckets(), hasSize(0)); + }); + } + + public void testNull() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD, FLOAT_FIELD)), NONE_DECORATOR, iw -> { + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument(asList(new NumericDocValuesField(INT_FIELD, 1), new FloatDocValuesField(FLOAT_FIELD, 2.0f))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new FloatDocValuesField(FLOAT_FIELD, 2.0f))); + 
iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("d")), new NumericDocValuesField(INT_FIELD, 3))); + + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(1)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L), equalTo(1.0))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + }); + + } + + public void testMissing() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(KEYWORD_FIELD).setMissing("a").build(), + new MultiTermsValuesSourceConfig.Builder().setFieldName(INT_FIELD).setMissing(1).build(), + new MultiTermsValuesSourceConfig.Builder().setFieldName(FLOAT_FIELD).setMissing(2.0f).build() + ), + NONE_DECORATOR, + iw -> { + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + iw.addDocument( + asList( + // missing KEYWORD_FIELD + new NumericDocValuesField(INT_FIELD, 1), + new FloatDocValuesField(FLOAT_FIELD, 1.0f) + ) + ); + iw.addDocument( + asList( + new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), + // missing INT_FIELD + new FloatDocValuesField(FLOAT_FIELD, 2.0f) + ) + ); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new NumericDocValuesField(INT_FIELD, 2) + // missing FLOAT_FIELD + )); + iw.addDocument(singletonList(new SortedDocValuesField(UNRELATED_KEYWORD_FIELD, new BytesRef("unrelated")))); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(4)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(1L), equalTo(1.0))); + 
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("b"), equalTo(1L), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(3).getKey(), contains(equalTo("c"), equalTo(2L), equalTo(2.0))); + MatcherAssert.assertThat(h.getBuckets().get(3).getDocCount(), equalTo(1L)); + } + ); + } + + public void testMixKeywordAndBoolean() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, BOOL_FIELD)), NONE_DECORATOR, iw -> { + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(BOOL_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(BOOL_FIELD, 0))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(BOOL_FIELD, 0))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(BOOL_FIELD, 1))); + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(4)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(false))); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|false")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo(true))); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("a|true")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("b"), equalTo(false))); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("b|false")); + 
MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(3).getKey(), contains(equalTo("b"), equalTo(true))); + MatcherAssert.assertThat(h.getBuckets().get(3).getKeyAsString(), equalTo("b|true")); + MatcherAssert.assertThat(h.getBuckets().get(3).getDocCount(), equalTo(1L)); + }); + } + + public void testGeoPointField() { + assertThrows( + IllegalArgumentException.class, + () -> testAggregation( + new MatchAllDocsQuery(), + asList(term(KEYWORD_FIELD), term(GEO_POINT_FIELD)), + NONE_DECORATOR, + iw -> {}, + f -> fail("should throw exception") + ) + ); + } + + public void testMinDocCount() throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), b -> b.minDocCount(2), iw -> { + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(INT_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(INT_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(INT_FIELD, 2))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(INT_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new NumericDocValuesField(INT_FIELD, 2))); + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(1)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + }); + } + + public void testIncludeExclude() throws IOException { + testAggregation( + new MatchAllDocsQuery(), + asList( + new MultiTermsValuesSourceConfig.Builder().setFieldName(KEYWORD_FIELD) + .setIncludeExclude(new IncludeExclude("a", null)) + .build(), + term(INT_FIELD) + ), + NONE_DECORATOR, + iw -> { + 
iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(INT_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), new NumericDocValuesField(INT_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("b")), new NumericDocValuesField(INT_FIELD, 1))); + iw.addDocument(asList(new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("c")), new NumericDocValuesField(INT_FIELD, 2))); + }, + h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(1)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo(1L))); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + } + ); + } + + private void testAggregation( + Query query, + List terms, + Consumer decorator, + CheckedConsumer indexBuilder, + Consumer verify + ) throws IOException { + MultiTermsAggregationBuilder builder = createTestAggregatorBuilder(terms); + if (decorator != NONE_DECORATOR) { + decorator.accept(builder); + } + testCase(builder, query, indexBuilder, verify, mappedFieldTypeMap.values().toArray(new MappedFieldType[] {})); + } + + private MultiTermsValuesSourceConfig term(String field) { + return new MultiTermsValuesSourceConfig.Builder().setFieldName(field).build(); + } + + private MultiTermsAggregationBuilder createTestAggregatorBuilder(List termsConfig) { + MultiTermsAggregationBuilder factory = new MultiTermsAggregationBuilder(AGG_NAME); + factory.terms(termsConfig); + + if (randomBoolean()) { + factory.size(randomIntBetween(10, Integer.MAX_VALUE)); + } + if (randomBoolean()) { + factory.shardSize(randomIntBetween(10, Integer.MAX_VALUE)); + } + if (randomBoolean()) { + factory.showTermDocCountError(randomBoolean()); + } + return factory; + } + + private List fieldConfigs(List terms) { + List termConfigs = new ArrayList<>(); + for (String term : terms) { + termConfigs.add(term(term)); + } + return termConfigs; + } + + 
private static DateFieldMapper.DateFieldType dateFieldType(String name) { + return new DateFieldMapper.DateFieldType( + name, + true, + false, + true, + DateFormatter.forPattern("date"), + DateFieldMapper.Resolution.MILLISECONDS, + null, + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfigTests.java b/server/src/test/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfigTests.java new file mode 100644 index 00000000000..a142faa2048 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/support/MultiTermsValuesSourceConfigTests.java @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.support; + +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.script.Script; +import org.opensearch.search.SearchModule; +import org.opensearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.time.ZoneId; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + +public class MultiTermsValuesSourceConfigTests extends AbstractSerializingTestCase { + + @Override + protected MultiTermsValuesSourceConfig doParseInstance(XContentParser parser) throws IOException { + return MultiTermsValuesSourceConfig.PARSER.apply(true, true, true, true).apply(parser, null).build(); + } + + @Override + protected MultiTermsValuesSourceConfig createTestInstance() { + String field = randomAlphaOfLength(10); + Object missing = randomBoolean() ? 
randomAlphaOfLength(10) : null; + ZoneId timeZone = randomBoolean() ? randomZone() : null; + Script script = randomBoolean() ? new Script(randomAlphaOfLength(10)) : null; + return new MultiTermsValuesSourceConfig.Builder().setFieldName(field) + .setMissing(missing) + .setScript(script) + .setTimeZone(timeZone) + .build(); + } + + @Override + protected Writeable.Reader instanceReader() { + return MultiTermsValuesSourceConfig::new; + } + + public void testMissingFieldScript() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new MultiTermsValuesSourceConfig.Builder().build()); + assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be null. Please specify one or the other.")); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + } +} diff --git a/test/framework/src/main/java/org/opensearch/test/InternalAggregationTestCase.java b/test/framework/src/main/java/org/opensearch/test/InternalAggregationTestCase.java index 6be7abffb9a..f138de152a4 100644 --- a/test/framework/src/main/java/org/opensearch/test/InternalAggregationTestCase.java +++ b/test/framework/src/main/java/org/opensearch/test/InternalAggregationTestCase.java @@ -101,9 +101,11 @@ import org.opensearch.search.aggregations.bucket.sampler.ParsedSampler; import org.opensearch.search.aggregations.bucket.terms.DoubleTerms; import org.opensearch.search.aggregations.bucket.terms.LongRareTerms; import org.opensearch.search.aggregations.bucket.terms.LongTerms; +import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.ParsedDoubleTerms; import 
org.opensearch.search.aggregations.bucket.terms.ParsedLongRareTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedLongTerms; +import org.opensearch.search.aggregations.bucket.terms.ParsedMultiTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedSignificantLongTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedSignificantStringTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedStringRareTerms; @@ -289,6 +291,7 @@ public abstract class InternalAggregationTestCase map.put(IpRangeAggregationBuilder.NAME, (p, c) -> ParsedBinaryRange.fromXContent(p, (String) c)); map.put(TopHitsAggregationBuilder.NAME, (p, c) -> ParsedTopHits.fromXContent(p, (String) c)); map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c)); + map.put(MultiTermsAggregationBuilder.NAME, (p, c) -> ParsedMultiTerms.fromXContent(p, (String) c)); namedXContents = map.entrySet() .stream()