Fix scripted metric in ccs (backport of #54776) (#54888)

`scripted_metric` did not work with cross cluster search because it
assumed that you'd never perform a partial reduction, serialize the
results, and then perform a final reduction. That
serialized-after-partial-reduction step was broken.

This is also required to support #54758.
This commit is contained in:
Nik Everett 2020-04-07 10:43:00 -04:00 committed by GitHub
parent 915092dc28
commit 3c56e0de42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 18 deletions

View File

@ -196,6 +196,44 @@
- match: { aggregations.cluster.buckets.1.animal.buckets.1.s.value: 0 }
- match: { aggregations.cluster.buckets.1.average_sum.value: 1 }
# scripted_metric
- do:
search:
index: test_index,my_remote_cluster:test_index
body:
seq_no_primary_term: true
aggs:
cluster:
terms:
field: f1.keyword
aggs:
animal_length:
scripted_metric:
init_script: |
state.sum = 0
map_script: |
state.sum += doc['animal.keyword'].value.length()
combine_script: |
state.sum
reduce_script: |
long sum = 0;
for (s in states) {
sum += s;
}
return sum
- match: { num_reduce_phases: 3 }
- match: {_clusters.total: 2}
- match: {_clusters.successful: 2}
- match: {_clusters.skipped: 0}
- match: { _shards.total: 5 }
- match: { hits.total.value: 11 }
- length: { aggregations.cluster.buckets: 2 }
- match: { aggregations.cluster.buckets.0.key: "remote_cluster" }
- match: { aggregations.cluster.buckets.0.doc_count: 6 }
- match: { aggregations.cluster.buckets.0.animal_length.value: 34 }
- match: { aggregations.cluster.buckets.1.key: "local_cluster" }
- match: { aggregations.cluster.buckets.1.animal_length.value: 15 }
---
"Add transient remote cluster based on the preset cluster":
- do:

View File

@ -19,12 +19,13 @@
package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptedMetricAggContexts;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
@ -36,19 +37,21 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import static java.util.Collections.singletonList;
public class InternalScriptedMetric extends InternalAggregation implements ScriptedMetric {
final Script reduceScript;
private final List<Object> aggregation;
private final List<Object> aggregations;
InternalScriptedMetric(String name, Object aggregation, Script reduceScript, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metadata) {
this(name, Collections.singletonList(aggregation), reduceScript, pipelineAggregators, metadata);
}
private InternalScriptedMetric(String name, List<Object> aggregation, Script reduceScript, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metadata) {
private InternalScriptedMetric(String name, List<Object> aggregations, Script reduceScript,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metadata) {
super(name, pipelineAggregators, metadata);
this.aggregation = aggregation;
this.aggregations = aggregations;
this.reduceScript = reduceScript;
}
@ -58,13 +61,29 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
public InternalScriptedMetric(StreamInput in) throws IOException {
super(in);
reduceScript = in.readOptionalWriteable(Script::new);
aggregation = Collections.singletonList(in.readGenericValue());
if (in.getVersion().before(Version.V_7_8_0)) {
aggregations = singletonList(in.readGenericValue());
} else {
aggregations = in.readList(StreamInput::readGenericValue);
}
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(reduceScript);
out.writeGenericValue(aggregation());
if (out.getVersion().before(Version.V_7_8_0)) {
/*
 * The pre-7.8.0 wire format carries exactly one generic value, so
 * only a result holding a single entry can be written to it. More
 * than one entry means the result has not been fully reduced.
 */
if (aggregations.size() > 1) {
/*
* I *believe* that this situation can only happen in cross
* cluster search right now. Thus the message. But computers
* are hard.
*/
throw new IllegalArgumentException("scripted_metric doesn't support cross cluster search until 7.8.0");
}
out.writeGenericValue(aggregations.get(0));
} else {
// 7.8.0 and later write the whole list so partially reduced results survive serialization.
out.writeCollection(aggregations, StreamOutput::writeGenericValue);
}
}
@Override
@ -74,14 +93,14 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
@Override
public Object aggregation() {
if (aggregation.size() != 1) {
if (aggregations.size() != 1) {
throw new IllegalStateException("aggregation was not reduced");
}
return aggregation.get(0);
return aggregations.get(0);
}
List<Object> getAggregation() {
return aggregation;
return aggregations;
}
@Override
@ -89,7 +108,7 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
List<Object> aggregationObjects = new ArrayList<>();
for (InternalAggregation aggregation : aggregations) {
InternalScriptedMetric mapReduceAggregation = (InternalScriptedMetric) aggregation;
aggregationObjects.addAll(mapReduceAggregation.aggregation);
aggregationObjects.addAll(mapReduceAggregation.aggregations);
}
InternalScriptedMetric firstAggregation = ((InternalScriptedMetric) aggregations.get(0));
List<Object> aggregation;
@ -142,12 +161,12 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
InternalScriptedMetric other = (InternalScriptedMetric) obj;
return Objects.equals(reduceScript, other.reduceScript) &&
Objects.equals(aggregation, other.aggregation);
Objects.equals(aggregations, other.aggregations);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), reduceScript, aggregation);
return Objects.hash(super.hashCode(), reduceScript, aggregations);
}
}

View File

@ -132,7 +132,7 @@ public class InternalScriptedMetricTests extends InternalAggregationTestCase<Int
if (hasReduceScript) {
assertEquals(inputs.size(), reduced.aggregation());
} else {
assertEquals(inputs.size(), ((List<Object>) reduced.aggregation()).size());
assertEquals(inputs.size(), ((List<?>) reduced.aggregation()).size());
}
}

View File

@ -281,7 +281,7 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
return createTestInstance(name, metadata);
}
public void testReduceRandom() {
public void testReduceRandom() throws IOException {
String name = randomAlphaOfLength(5);
List<T> inputs = new ArrayList<>();
List<InternalAggregation> toReduce = new ArrayList<>();
@ -296,7 +296,7 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
ScriptService mockScriptService = mockScriptService();
MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
if (randomBoolean() && toReduce.size() > 1) {
// sometimes do an incremental reduce
// sometimes do a partial reduce
Collections.shuffle(toReduce, random());
int r = randomIntBetween(1, toReduceSize);
List<InternalAggregation> internalAggregations = toReduce.subList(0, r);
@ -311,6 +311,14 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
int reducedBucketCount = countInnerBucket(reduced);
//check that non final reduction never adds buckets
assertThat(reducedBucketCount, lessThanOrEqualTo(initialBucketCount));
/*
* Sometimes serializing and deserializing the partially reduced
* result to simulate the compaction that we attempt after a
* partial reduce. And to simulate cross cluster search.
*/
if (randomBoolean()) {
reduced = copyInstance(reduced);
}
toReduce = new ArrayList<>(toReduce.subList(r, toReduceSize));
toReduce.add(reduced);
}

View File

@ -42,8 +42,16 @@ public class InternalStringStatsTests extends InternalAggregationTestCase<Intern
if (randomBoolean()) {
return new InternalStringStats(name, 0, 0, 0, 0, emptyMap(), randomBoolean(), DocValueFormat.RAW, emptyList(), metadata);
}
return new InternalStringStats(name, randomLongBetween(1, Long.MAX_VALUE),
randomNonNegativeLong(), between(0, Integer.MAX_VALUE), between(0, Integer.MAX_VALUE), randomCharOccurrences(),
/*
* Pick random count and length that are *much* less than
* Long.MAX_VALUE because reduction adds them together and sometimes
* serializes them and that serialization would fail if the sum has
* wrapped to a negative number.
*/
long count = randomLongBetween(1, Integer.MAX_VALUE);
long totalLength = randomLongBetween(0, count * 10);
return new InternalStringStats(name, count, totalLength,
between(0, Integer.MAX_VALUE), between(0, Integer.MAX_VALUE), randomCharOccurrences(),
randomBoolean(), DocValueFormat.RAW,
emptyList(), metadata);
};