From 1a58a487f0dd45a76f700abad0d3f9e7dd5eab37 Mon Sep 17 00:00:00 2001 From: Nikita Glashenko Date: Sat, 3 Aug 2019 00:09:48 +0500 Subject: [PATCH] Add more flexibility to MovingFunction window alignment (#44360) Introduce shift field to MovingFunction aggregation. By default, shift = 0. Behavior, in this case, is the same as before. Increasing shift by 1 moves starting window position by 1 to the right. To simply include current bucket to the window, use shift = 1 For center alignment (n/2 values before and after the current bucket), use shift = window / 2 For right alignment (n values after the current bucket), use shift = window. --- .../pipeline/movfn-aggregation.asciidoc | 17 ++++++- .../MovFnPipelineAggregationBuilder.java | 26 +++++++++-- .../pipeline/MovFnPipelineAggregator.java | 44 ++++++++++++++++--- ...eAggregationBuilderSerializationTests.java | 10 ++++- .../aggregations/pipeline/MovFnUnitTests.java | 42 +++++++++++++----- 5 files changed, 113 insertions(+), 26 deletions(-) diff --git a/docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc b/docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc index ea414237174..cdea58d45ae 100644 --- a/docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc +++ b/docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc @@ -24,14 +24,15 @@ A `moving_fn` aggregation looks like this in isolation: -------------------------------------------------- // NOTCONSOLE -[[moving-avg-params]] -.`moving_avg` Parameters +[[moving-fn-params]] +.`moving_fn` Parameters [options="header"] |=== |Parameter Name |Description |Required |Default Value |`buckets_path` |Path to the metric of interest (see <> for more details |Required | |`window` |The size of window to "slide" across the histogram. |Required | |`script` |The script that should be executed on each window of data |Required | +|`shift` |<> of window position. |Optional | 0 |=== `moving_fn` aggregations must be embedded inside of a `histogram` or `date_histogram` aggregation. They can be @@ -169,6 +170,18 @@ POST /_search // CONSOLE // TEST[setup:sales] +[[shift-parameter]] +==== shift parameter + +By default (with `shift = 0`), the window that is offered for calculation is the last `n` values excluding the current bucket. +Increasing `shift` by 1 moves starting window position by `1` to the right. + +- To include current bucket to the window, use `shift = 1`. +- For center alignment (`n / 2` values before and after the current bucket), use `shift = window / 2`. +- For right alignment (`n` values after the current bucket), use `shift = window`. + +If either of window edges moves outside the borders of data series, the window shrinks to include available values only. + ==== Pre-built Functions For convenience, a number of functions have been prebuilt and are available inside the `moving_fn` script context: diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilder.java index 44f26c3c32b..7d56197cc46 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilder.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.pipeline; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -48,12 +49,14 @@ import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregator. public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregationBuilder { public static final String NAME = "moving_fn"; private static final ParseField WINDOW = new ParseField("window"); + private static final ParseField SHIFT = new ParseField("shift"); private final Script script; private final String bucketsPathString; private String format = null; private GapPolicy gapPolicy = GapPolicy.SKIP; private int window; + private int shift; private static final Function> PARSER = name -> { @@ -68,6 +71,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation (p, c) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); parser.declareInt(ConstructingObjectParser.constructorArg(), WINDOW); + parser.declareInt(MovFnPipelineAggregationBuilder::setShift, SHIFT); parser.declareString(MovFnPipelineAggregationBuilder::format, FORMAT); parser.declareField(MovFnPipelineAggregationBuilder::gapPolicy, p -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { @@ -97,6 +101,11 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation format = in.readOptionalString(); gapPolicy = GapPolicy.readFrom(in); window = in.readInt(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport + shift = in.readInt(); + } else { + shift = 0; + } } @Override @@ -106,6 +115,9 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation out.writeOptionalString(format); gapPolicy.writeTo(out); out.writeInt(window); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport + out.writeInt(shift); + } } /** @@ -168,9 +180,13 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation this.window = window; } + public void setShift(int shift) { + this.shift = shift; + } + @Override public void doValidate(AggregatorFactory parent, Collection aggFactories, - Collection pipelineAggregatoractories) { + Collection pipelineAggregatorFactories) { if (window <= 0) { throw new IllegalArgumentException("[" + WINDOW.getPreferredName() + "] must be a positive, non-zero integer."); } @@ -180,7 +196,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation @Override protected PipelineAggregator createInternal(Map metaData) { - return new MovFnPipelineAggregator(name, bucketsPathString, script, window, formatter(), gapPolicy, metaData); + return new MovFnPipelineAggregator(name, bucketsPathString, script, window, shift, formatter(), gapPolicy, metaData); } @Override @@ -192,6 +208,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation } builder.field(GAP_POLICY.getPreferredName(), gapPolicy.getName()); builder.field(WINDOW.getPreferredName(), window); + builder.field(SHIFT.getPreferredName(), shift); return builder; } @@ -225,7 +242,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation @Override public int hashCode() { - return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window); + return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window, shift); } @Override @@ -238,7 +255,8 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation && Objects.equals(script, other.script) && Objects.equals(format, other.format) && Objects.equals(gapPolicy, other.gapPolicy) - && Objects.equals(window, other.window); + && Objects.equals(window, other.window) + && Objects.equals(shift, other.shift); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregator.java index 4f14df2d66d..b0915350c26 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregator.java @@ -19,7 +19,7 @@ package org.elasticsearch.search.aggregations.pipeline; -import org.elasticsearch.common.collect.EvictingQueue; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.script.Script; @@ -63,8 +63,9 @@ public class MovFnPipelineAggregator extends PipelineAggregator { private final Script script; private final String bucketsPath; private final int window; + private final int shift; - MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, DocValueFormat formatter, + MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, int shift, DocValueFormat formatter, BucketHelpers.GapPolicy gapPolicy, Map metadata) { super(name, new String[]{bucketsPath}, metadata); this.bucketsPath = bucketsPath; @@ -72,6 +73,7 @@ public class MovFnPipelineAggregator extends PipelineAggregator { this.formatter = formatter; this.gapPolicy = gapPolicy; this.window = window; + this.shift = shift; } public MovFnPipelineAggregator(StreamInput in) throws IOException { @@ -81,6 +83,11 @@ public class MovFnPipelineAggregator extends PipelineAggregator { gapPolicy = BucketHelpers.GapPolicy.readFrom(in); bucketsPath = in.readString(); window = in.readInt(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport + shift = in.readInt(); + } else { + shift = 0; + } } @Override @@ -90,6 +97,9 @@ public class MovFnPipelineAggregator extends PipelineAggregator { gapPolicy.writeTo(out); out.writeString(bucketsPath); out.writeInt(window); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport + out.writeInt(shift); + } } @Override @@ -106,7 +116,6 @@ public class MovFnPipelineAggregator extends PipelineAggregator { HistogramFactory factory = (HistogramFactory) histo; List newBuckets = new ArrayList<>(); - EvictingQueue values = new EvictingQueue<>(this.window); // Initialize the script MovingFunctionScript.Factory scriptFactory = reduceContext.scriptService().compile(script, MovingFunctionScript.CONTEXT); @@ -117,6 +126,12 @@ public class MovFnPipelineAggregator extends PipelineAggregator { MovingFunctionScript executableScript = scriptFactory.newInstance(); + List values = buckets.stream() + .map(b -> resolveBucketValue(histo, b, bucketsPaths()[0], gapPolicy)) + .filter(v -> v != null && v.isNaN() == false) + .collect(Collectors.toList()); + + int index = 0; for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) { Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy); @@ -124,11 +139,18 @@ public class MovFnPipelineAggregator extends PipelineAggregator { // since we only change newBucket if we can add to it MultiBucketsAggregation.Bucket newBucket = bucket; - if (thisBucketValue != null && thisBucketValue.equals(Double.NaN) == false) { + if (thisBucketValue != null && thisBucketValue.isNaN() == false) { // The custom context mandates that the script returns a double (not Double) so we // don't need null checks, etc. - double movavg = executableScript.execute(vars, values.stream().mapToDouble(Double::doubleValue).toArray()); + int fromIndex = clamp(index - window + shift, values); + int toIndex = clamp(index + shift, values); + double movavg = executableScript.execute( + vars, + values.subList(fromIndex, toIndex).stream() + .mapToDouble(Double::doubleValue) + .toArray() + ); List aggs = StreamSupport .stream(bucket.getAggregations().spliterator(), false) @@ -136,11 +158,21 @@ public class MovFnPipelineAggregator extends PipelineAggregator { .collect(Collectors.toList()); aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<>(), metaData())); newBucket = factory.createBucket(factory.getKey(bucket), bucket.getDocCount(), new InternalAggregations(aggs)); - values.offer(thisBucketValue); + index++; } newBuckets.add(newBucket); } return factory.createAggregation(newBuckets); } + + private int clamp(int index, List list) { + if (index < 0) { + return 0; + } + if (index > list.size()) { + return list.size(); + } + return index; + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilderSerializationTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilderSerializationTests.java index 49923640805..cb1e2d5249b 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilderSerializationTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilderSerializationTests.java @@ -22,7 +22,6 @@ package org.elasticsearch.search.aggregations.pipeline; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.script.Script; -import org.elasticsearch.search.aggregations.pipeline.MovFnPipelineAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; @@ -31,7 +30,14 @@ public class MovFnPipelineAggregationBuilderSerializationTests extends AbstractS @Override protected MovFnPipelineAggregationBuilder createTestInstance() { - return new MovFnPipelineAggregationBuilder(randomAlphaOfLength(10), "foo", new Script("foo"), randomIntBetween(1, 10)); + MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder( + randomAlphaOfLength(10), + "foo", + new Script("foo"), + randomIntBetween(1, 10) + ); + builder.setShift(randomIntBetween(1, 10)); + return builder; } @Override diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java index 27490fa202b..862f5564555 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java @@ -53,6 +53,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; import static org.mockito.Mockito.mock; @@ -79,25 +80,42 @@ public class MovFnUnitTests extends AggregatorTestCase { private static final List datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10); public void testMatchAllDocs() throws IOException { - Query query = new MatchAllDocsQuery(); - Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap()); + check(0, List.of(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)); + } + public void testShift() throws IOException { + check(1, List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); + check(5, List.of(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN)); + check(-5, List.of(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0)); + } + + public void testWideWindow() throws IOException { + Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap()); + MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100); + builder.setShift(50); + check(builder, script, List.of(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0)); + } + + private void check(int shift, List expected) throws IOException { + Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap()); + MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3); + builder.setShift(shift); + check(builder, script, expected); + } + + private void check(MovFnPipelineAggregationBuilder builder, Script script, List expected) throws IOException { + Query query = new MatchAllDocsQuery(); DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo"); aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD); aggBuilder.subAggregation(new AvgAggregationBuilder("avg").field(VALUE_FIELD)); - aggBuilder.subAggregation(new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3)); + aggBuilder.subAggregation(builder); executeTestCase(query, aggBuilder, histogram -> { - assertEquals(10, histogram.getBuckets().size()); List buckets = histogram.getBuckets(); - for (int i = 0; i < buckets.size(); i++) { - if (i == 0) { - assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(Double.NaN)); - } else { - assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(((double) i))); - } - - } + List actual = buckets.stream() + .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value()) + .collect(Collectors.toList()); + assertThat(actual, equalTo(expected)); }, 1000, script); }