From 2aa94a5133a909c3072ec228b37034124fb6eceb Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 18 Jan 2016 12:00:33 -0500 Subject: [PATCH 01/11] Normalize unavailable load average This commit normalizes the one-minute load average obtained from OperatingSystemMXBean#getSystemLoadAverage to -1 when it is not available. This is to reflect the Javadocs for this method saying "If the load average is not available, a negative value is returned." --- core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java | 2 +- .../test/java/org/elasticsearch/monitor/os/OsProbeTests.java | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java b/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java index 077b4218fa9..3322c05fc97 100644 --- a/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java +++ b/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java @@ -124,7 +124,7 @@ public class OsProbe { } try { double oneMinuteLoadAverage = (double) getSystemLoadAverage.invoke(osMxBean); - return new double[] { oneMinuteLoadAverage, -1, -1 }; + return new double[] { oneMinuteLoadAverage >= 0 ? oneMinuteLoadAverage : -1, -1, -1 }; } catch (Throwable t) { return null; } diff --git a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java index 4f4319e212c..83007c5a86c 100644 --- a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java +++ b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java @@ -29,7 +29,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; public class OsProbeTests extends ESTestCase { @@ -71,7 +70,7 @@ public class OsProbeTests extends ESTestCase { // one minute load average is available, but 10-minute and 15-minute load averages are not // load average can be negative if not available or not computed yet, otherwise it should be >= 0 if (loadAverage != null) { - assertThat(loadAverage[0], anyOf(lessThan((double) 0), greaterThanOrEqualTo((double) 0))); + assertThat(loadAverage[0], anyOf(equalTo((double) -1), greaterThanOrEqualTo((double) 0))); assertThat(loadAverage[1], equalTo((double) -1)); assertThat(loadAverage[2], equalTo((double) -1)); } From e5013d16f00bae7c76585a597e5d5b70bbfed890 Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 18 Jan 2016 13:33:48 -0500 Subject: [PATCH 02/11] Fix test for load average on FreeBSD This commit fixes the test for load averages on FreeBSD. On FreeBSD, it is either the case that linprocfs is mounted at /compat/linux/proc in which case the load averages are available, or this is not the case and no load average are available. Previously, the test on FreeBSD was falling back to the catch all case which asserts that the five-minute and fifteen-minute load averages are not available. --- .../java/org/elasticsearch/monitor/os/OsProbeTests.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java index 83007c5a86c..8dd80231623 100644 --- a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java +++ b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java @@ -66,6 +66,12 @@ public class OsProbeTests extends ESTestCase { assertThat(loadAverage[0], greaterThanOrEqualTo((double) 0)); assertThat(loadAverage[1], greaterThanOrEqualTo((double) 0)); assertThat(loadAverage[2], greaterThanOrEqualTo((double) 0)); + } else if (Constants.FREE_BSD) { + // five- and fifteen-minute load averages not available if linprocfs is not mounted at /compat/linux/proc + assertNotNull(loadAverage); + assertThat(loadAverage[0], greaterThanOrEqualTo((double) 0)); + assertThat(loadAverage[1], anyOf(equalTo((double) -1), greaterThanOrEqualTo((double) 0))); + assertThat(loadAverage[2], anyOf(equalTo((double) -1), greaterThanOrEqualTo((double) 0))); } else { // one minute load average is available, but 10-minute and 15-minute load averages are not // load average can be negative if not available or not computed yet, otherwise it should be >= 0 From d55952a90d1c71a5db0fb89e1ab860b6c73de140 Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 18 Jan 2016 13:45:00 -0500 Subject: [PATCH 03/11] Tigthen load average test assertions This commit tightens the load average test assertions by separating out OS X as a system where we can make tighter assertions about the load average values. --- .../java/org/elasticsearch/monitor/os/OsProbeTests.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java index 8dd80231623..5aa1899c61d 100644 --- a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java +++ b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java @@ -72,9 +72,14 @@ public class OsProbeTests extends ESTestCase { assertThat(loadAverage[0], greaterThanOrEqualTo((double) 0)); assertThat(loadAverage[1], anyOf(equalTo((double) -1), greaterThanOrEqualTo((double) 0))); assertThat(loadAverage[2], anyOf(equalTo((double) -1), greaterThanOrEqualTo((double) 0))); - } else { + } else if (Constants.MAC_OS_X) { // one minute load average is available, but 10-minute and 15-minute load averages are not - // load average can be negative if not available or not computed yet, otherwise it should be >= 0 + assertNotNull(loadAverage); + assertThat(loadAverage[0], greaterThanOrEqualTo((double) 0)); + assertThat(loadAverage[1], equalTo((double) -1)); + assertThat(loadAverage[2], equalTo((double) -1)); + } else { + // unknown system, but the best case is that we have the one-minute load average if (loadAverage != null) { assertThat(loadAverage[0], anyOf(equalTo((double) -1), greaterThanOrEqualTo((double) 0))); assertThat(loadAverage[1], equalTo((double) -1)); From f5b72b0714095cc6b4fe9d302e1ce858b570ba4e Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 18 Jan 2016 14:00:15 -0500 Subject: [PATCH 04/11] Load average on Windows is never available Since load average is never available on Windows, this commit modifies the handling of load average there to just always return null. --- .../src/main/java/org/elasticsearch/monitor/os/OsProbe.java | 3 +++ .../java/org/elasticsearch/monitor/os/OsProbeTests.java | 6 +----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java b/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java index 3322c05fc97..5ee2232068f 100644 --- a/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java +++ b/core/src/main/java/org/elasticsearch/monitor/os/OsProbe.java @@ -119,6 +119,9 @@ public class OsProbe { } // fallback } + if (Constants.WINDOWS) { + return null; + } if (getSystemLoadAverage == null) { return null; } diff --git a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java index 5aa1899c61d..2edaad5c4ba 100644 --- a/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java +++ b/core/src/test/java/org/elasticsearch/monitor/os/OsProbeTests.java @@ -55,11 +55,7 @@ public class OsProbeTests extends ESTestCase { } if (Constants.WINDOWS) { // load average is unavailable on Windows - if (loadAverage != null) { - assertThat(loadAverage[0], equalTo((double) -1)); - assertThat(loadAverage[1], equalTo((double) -1)); - assertThat(loadAverage[2], equalTo((double) -1)); - } + assertNull(loadAverage); } else if (Constants.LINUX) { // we should be able to get the load average assertNotNull(loadAverage); From e62221168ea5f2a259ecf321f0dff1f40ffe12c3 Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 18 Jan 2016 14:01:48 -0500 Subject: [PATCH 05/11] No load average if the values are meaningless This commit modifies the presentation of load average to not be present in the response if all of the values are meaningless. --- core/src/main/java/org/elasticsearch/monitor/os/OsStats.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/elasticsearch/monitor/os/OsStats.java b/core/src/main/java/org/elasticsearch/monitor/os/OsStats.java index c419c4f2608..569f8825aa9 100644 --- a/core/src/main/java/org/elasticsearch/monitor/os/OsStats.java +++ b/core/src/main/java/org/elasticsearch/monitor/os/OsStats.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilderString; import java.io.IOException; +import java.util.Arrays; /** * @@ -89,7 +90,7 @@ public class OsStats implements Streamable, ToXContent { if (cpu != null) { builder.startObject(Fields.CPU); builder.field(Fields.PERCENT, cpu.getPercent()); - if (cpu.getLoadAverage() != null) { + if (cpu.getLoadAverage() != null && Arrays.stream(cpu.getLoadAverage()).anyMatch(load -> load != -1)) { builder.startObject(Fields.LOAD_AVERAGE); if (cpu.getLoadAverage()[0] != -1) { builder.field(Fields.LOAD_AVERAGE_1M, cpu.getLoadAverage()[0]); From 513f4e6c57ea15114010c69ffa1665fbce13b881 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Mon, 18 Jan 2016 15:31:16 +0100 Subject: [PATCH 06/11] Add serialization and fromXContent to SmoothingModels PhraseSuggestionBuilder uses three smoothing models internally. In order to enable proper serialization / parsing from xContent to the phrase suggester later, this change starts by making the smoothing models writable, adding hashCode/equals and fromXContent. --- .../suggest/phrase/PhraseSuggestParser.java | 6 +- .../phrase/PhraseSuggestionBuilder.java | 266 ++++++++++++++++-- .../AbstractShapeBuilderTestCase.java | 1 - .../suggest/phrase/LaplaceModelTests.java | 38 +++ .../phrase/LinearInterpolationModelTests.java | 55 ++++ .../suggest/phrase/SmoothingModelTest.java | 161 +++++++++++ .../phrase/StupidBackoffModelTests.java | 38 +++ 7 files changed, 539 insertions(+), 26 deletions(-) create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java create mode 100644 core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java index 0b904a95720..c226d061047 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java @@ -36,6 +36,8 @@ import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestContextParser; import org.elasticsearch.search.suggest.SuggestUtils; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionContext.DirectCandidateGenerator; import java.io.IOException; @@ -265,7 +267,7 @@ public final class PhraseSuggestParser implements SuggestContextParser { }); } else if ("laplace".equals(fieldName)) { ensureNoSmoothing(suggestion); - double theAlpha = 0.5; + double theAlpha = Laplace.DEFAULT_LAPLACE_ALPHA; while ((token = parser.nextToken()) != Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -286,7 +288,7 @@ public final class PhraseSuggestParser implements SuggestContextParser { } else if ("stupid_backoff".equals(fieldName) || "stupidBackoff".equals(fieldName)) { ensureNoSmoothing(suggestion); - double theDiscount = 0.4; + double theDiscount = StupidBackoff.DEFAULT_BACKOFF_DISCOUNT; while ((token = parser.nextToken()) != Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { fieldName = parser.currentName(); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java index 1055fbe83fc..97ca09d25a1 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java @@ -18,8 +18,16 @@ */ package org.elasticsearch.search.suggest.phrase; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParser.Token; +import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder; @@ -29,6 +37,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Set; /** @@ -41,7 +50,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder> generators = new HashMap<>(); private Integer gramSize; - private SmoothingModel model; + private SmoothingModel model; private Boolean forceUnigrams; private Integer tokenLimit; private String preTag; @@ -150,7 +159,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder model) { this.model = model; return this; } @@ -283,8 +292,15 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class StupidBackoff extends SmoothingModel { - private final double discount; + public static final class StupidBackoff extends SmoothingModel { + /** + * Default discount parameter for {@link StupidBackoff} smoothing + */ + public static final double DEFAULT_BACKOFF_DISCOUNT = 0.4; + private double discount = DEFAULT_BACKOFF_DISCOUNT; + static final StupidBackoff PROTOTYPE = new StupidBackoff(DEFAULT_BACKOFF_DISCOUNT); + private static final String NAME = "stupid_backoff"; + private static final ParseField DISCOUNT_FIELD = new ParseField("discount"); /** * Creates a Stupid-Backoff smoothing model. @@ -293,15 +309,63 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class Laplace extends SmoothingModel { - private final double alpha; + public static final class Laplace extends SmoothingModel { + private double alpha = DEFAULT_LAPLACE_ALPHA; + private static final String NAME = "laplace"; + private static final ParseField ALPHA_FIELD = new ParseField("alpha"); + /** + * Default alpha parameter for laplace smoothing + */ + public static final double DEFAULT_LAPLACE_ALPHA = 0.5; + static final Laplace PROTOTYPE = new Laplace(DEFAULT_LAPLACE_ALPHA); + /** * Creates a Laplace smoothing model. * */ public Laplace(double alpha) { - super("laplace"); this.alpha = alpha; } + /** + * @return the laplace model alpha parameter + */ + public double getAlpha() { + return this.alpha; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("alpha", alpha); + builder.field(ALPHA_FIELD.getPreferredName(), alpha); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(alpha); + } + + @Override + public Laplace readFrom(StreamInput in) throws IOException { + return new Laplace(in.readDouble()); + } + + @Override + protected boolean doEquals(Laplace other) { + return Objects.equals(alpha, other.alpha); + } + + @Override + public final int hashCode() { + return Objects.hash(alpha); + } + + @Override + public Laplace fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + double alpha = DEFAULT_LAPLACE_ALPHA; + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } + if (token.isValue() && parseContext.parseFieldMatcher().match(fieldName, ALPHA_FIELD)) { + alpha = parser.doubleValue(); + } + } + return new Laplace(alpha); + } } - public static abstract class SmoothingModel implements ToXContent { - private final String type; - - protected SmoothingModel(String type) { - this.type = type; - } + public static abstract class SmoothingModel> implements NamedWriteable, ToXContent { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(type); + builder.startObject(getWriteableName()); innerToXContent(builder,params); builder.endObject(); return builder; } + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + @SuppressWarnings("unchecked") + SM other = (SM) obj; + return doEquals(other); + } + + public abstract SM fromXContent(QueryParseContext parseContext) throws IOException; + + /** + * subtype specific implementation of "equals". + */ + protected abstract boolean doEquals(SM other); + protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; } @@ -358,10 +493,15 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class LinearInterpolation extends SmoothingModel { + public static final class LinearInterpolation extends SmoothingModel { + private static final String NAME = "linear"; + static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1); private final double trigramLambda; private final double bigramLambda; private final double unigramLambda; + private static final ParseField TRIGRAM_FIELD = new ParseField("trigram_lambda"); + private static final ParseField BIGRAM_FIELD = new ParseField("bigram_lambda"); + private static final ParseField UNIGRAM_FIELD = new ParseField("unigram_lambda"); /** * Creates a linear interpolation smoothing model. @@ -376,19 +516,99 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder 0.001) { + throw new IllegalArgumentException("linear smoothing lambdas must sum to 1"); + } this.trigramLambda = trigramLambda; this.bigramLambda = bigramLambda; this.unigramLambda = unigramLambda; } + public double getTrigramLambda() { + return this.trigramLambda; + } + + public double getBigramLambda() { + return this.bigramLambda; + } + + public double getUnigramLambda() { + return this.unigramLambda; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("trigram_lambda", trigramLambda); - builder.field("bigram_lambda", bigramLambda); - builder.field("unigram_lambda", unigramLambda); + builder.field(TRIGRAM_FIELD.getPreferredName(), trigramLambda); + builder.field(BIGRAM_FIELD.getPreferredName(), bigramLambda); + builder.field(UNIGRAM_FIELD.getPreferredName(), unigramLambda); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(trigramLambda); + out.writeDouble(bigramLambda); + out.writeDouble(unigramLambda); + } + + @Override + public LinearInterpolation readFrom(StreamInput in) throws IOException { + return new LinearInterpolation(in.readDouble(), in.readDouble(), in.readDouble()); + } + + @Override + protected boolean doEquals(LinearInterpolation other) { + return Objects.equals(trigramLambda, other.trigramLambda) && + Objects.equals(bigramLambda, other.bigramLambda) && + Objects.equals(unigramLambda, other.unigramLambda); + } + + @Override + public final int hashCode() { + return Objects.hash(trigramLambda, bigramLambda, unigramLambda); + } + + @Override + public LinearInterpolation fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + final double[] lambdas = new double[3]; + ParseFieldMatcher matcher = parseContext.parseFieldMatcher(); + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } + if (token.isValue()) { + if (matcher.match(fieldName, TRIGRAM_FIELD)) { + lambdas[0] = parser.doubleValue(); + if (lambdas[0] < 0) { + throw new IllegalArgumentException("trigram_lambda must be positive"); + } + } else if (matcher.match(fieldName, BIGRAM_FIELD)) { + lambdas[1] = parser.doubleValue(); + if (lambdas[1] < 0) { + throw new IllegalArgumentException("bigram_lambda must be positive"); + } + } else if (matcher.match(fieldName, UNIGRAM_FIELD)) { + lambdas[2] = parser.doubleValue(); + if (lambdas[2] < 0) { + throw new IllegalArgumentException("unigram_lambda must be positive"); + } + } else { + throw new IllegalArgumentException( + "suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]"); + } + } + } + return new LinearInterpolation(lambdas[0], lambdas[1], lambdas[2]); + } } /** @@ -428,7 +648,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder exte } XContentBuilder builder = testShape.toXContent(contentBuilder, ToXContent.EMPTY_PARAMS); XContentParser shapeParser = XContentHelper.createParser(builder.bytes()); - XContentHelper.createParser(builder.bytes()); shapeParser.nextToken(); ShapeBuilder parsedShape = ShapeBuilder.parse(shapeParser); assertNotSame(testShape, parsedShape); diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java new file mode 100644 index 00000000000..e2256e98f6a --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; + +public class LaplaceModelTests extends SmoothingModelTest { + + @Override + protected Laplace createTestModel() { + return new Laplace(randomDoubleBetween(0.0, 10.0, false)); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected Laplace createMutation(Laplace original) { + return new Laplace(original.getAlpha() + 0.1); + } +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java new file mode 100644 index 00000000000..467bca7f0ab --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java @@ -0,0 +1,55 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; + +public class LinearInterpolationModelTests extends SmoothingModelTest { + + @Override + protected LinearInterpolation createTestModel() { + double trigramLambda = randomDoubleBetween(0.0, 10.0, false); + double bigramLambda = randomDoubleBetween(0.0, 10.0, false); + double unigramLambda = randomDoubleBetween(0.0, 10.0, false); + // normalize + double sum = trigramLambda + bigramLambda + unigramLambda; + return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected LinearInterpolation createMutation(LinearInterpolation original) { + // swap two values permute original lambda values + switch (randomIntBetween(0, 2)) { + case 0: + // swap first two + return new LinearInterpolation(original.getBigramLambda(), original.getTrigramLambda(), original.getUnigramLambda()); + case 1: + // swap last two + return new LinearInterpolation(original.getTrigramLambda(), original.getUnigramLambda(), original.getBigramLambda()); + case 2: + default: + // swap first and last + return new LinearInterpolation(original.getUnigramLambda(), original.getBigramLambda(), original.getTrigramLambda()); + } + } +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java new file mode 100644 index 00000000000..b2dbe17e67d --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryParseContext; +import org.elasticsearch.indices.query.IndicesQueriesRegistry; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public abstract class SmoothingModelTest> extends ESTestCase { + + private static NamedWriteableRegistry namedWriteableRegistry; + private static IndicesQueriesRegistry indicesQueriesRegistry; + + /** + * setup for the whole base test class + */ + @BeforeClass + public static void init() { + if (namedWriteableRegistry == null) { + namedWriteableRegistry = new NamedWriteableRegistry(); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, Laplace.PROTOTYPE); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE); + } + indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry); + } + + @AfterClass + public static void afterClass() throws Exception { + namedWriteableRegistry = null; + indicesQueriesRegistry = null; + } + + /** + * create random model that is put under test + */ + protected abstract SM createTestModel(); + + /** + * mutate the given model so the returned smoothing model is different + */ + protected abstract SM createMutation(SM original) throws IOException; + + /** + * Test that creates new smoothing model from a random test smoothing model and checks both for equality + */ + public void testFromXContent() throws IOException { + QueryParseContext context = new QueryParseContext(indicesQueriesRegistry); + context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY)); + + SM testModel = createTestModel(); + XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); + if (randomBoolean()) { + contentBuilder.prettyPrint(); + } + contentBuilder.startObject(); + testModel.innerToXContent(contentBuilder, ToXContent.EMPTY_PARAMS); + contentBuilder.endObject(); + XContentParser parser = XContentHelper.createParser(contentBuilder.bytes()); + context.reset(parser); + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, + testModel.getWriteableName()); + SmoothingModel parsedModel = prototype.fromXContent(context); + assertNotSame(testModel, parsedModel); + assertEquals(testModel, parsedModel); + assertEquals(testModel.hashCode(), parsedModel.hashCode()); + } + + /** + * Test serialization and deserialization of the tested model. + */ + @SuppressWarnings("unchecked") + public void testSerialization() throws IOException { + SM testModel = createTestModel(); + SM deserializedModel = (SM) copyModel(testModel); + assertEquals(testModel, deserializedModel); + assertEquals(testModel.hashCode(), deserializedModel.hashCode()); + assertNotSame(testModel, deserializedModel); + } + + /** + * Test equality and hashCode properties + */ + @SuppressWarnings("unchecked") + public void testEqualsAndHashcode() throws IOException { + SM firstModel = createTestModel(); + assertFalse("smoothing model is equal to null", firstModel.equals(null)); + assertFalse("smoothing model is equal to incompatible type", firstModel.equals("")); + assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel)); + assertThat("same smoothing model's hashcode returns different values if called multiple times", firstModel.hashCode(), + equalTo(firstModel.hashCode())); + assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel))); + + SM secondModel = (SM) copyModel(firstModel); + assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel)); + assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel)); + assertTrue("equals is not symmetric", secondModel.equals(firstModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode())); + + SM thirdModel = (SM) copyModel(secondModel); + assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel)); + assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode())); + assertTrue("equals is not transitive", firstModel.equals(thirdModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", firstModel.hashCode(), equalTo(thirdModel.hashCode())); + assertTrue("equals is not symmetric", thirdModel.equals(secondModel)); + assertTrue("equals is not symmetric", thirdModel.equals(firstModel)); + } + + static SmoothingModel copyModel(SmoothingModel original) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + original.writeTo(output); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); + return prototype.readFrom(in); + } + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java new file mode 100644 index 00000000000..5d774066e07 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; + +public class StupidBackoffModelTests extends SmoothingModelTest { + + @Override + protected StupidBackoff createTestModel() { + return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false)); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected StupidBackoff createMutation(StupidBackoff original) { + return new StupidBackoff(original.getDiscount() + 0.1); + } +} From e035dabd4dbc1d3ba573cd629de1a4f38bccd9fa Mon Sep 17 00:00:00 2001 From: Igor Motov Date: Mon, 11 Jan 2016 12:36:55 -0500 Subject: [PATCH 07/11] Extend tracking of parent tasks to master node, replication and broadcast actions Now MasterNodeOperations, ReplicationAllShards, ReplicationSingleShard, BroadcastReplication and BroadcastByNode actions keep track of their parent tasks. --- .../upgrade/post/TransportUpgradeAction.java | 5 +- .../query/TransportValidateQueryAction.java | 5 +- .../action/delete/TransportDeleteAction.java | 15 +- .../action/index/TransportIndexAction.java | 15 +- .../percolate/TransportPercolateAction.java | 7 +- .../support/ChildTaskActionRequest.java | 66 +++++ .../action/support/ChildTaskRequest.java | 4 +- .../support/HandledTransportAction.java | 9 +- .../action/support/TransportAction.java | 10 +- .../broadcast/BroadcastShardRequest.java | 4 +- .../broadcast/TransportBroadcastAction.java | 15 +- .../node/TransportBroadcastByNodeAction.java | 20 +- .../support/master/MasterNodeRequest.java | 3 +- .../master/TransportMasterNodeAction.java | 3 + .../replication/ReplicationRequest.java | 3 +- .../TransportBroadcastReplicationAction.java | 17 +- .../TransportReplicationAction.java | 24 +- .../support/tasks/BaseTasksRequest.java | 22 +- .../dfs/TransportDfsOnlyAction.java | 5 +- .../org/elasticsearch/tasks/ChildTask.java | 57 ---- .../java/org/elasticsearch/tasks/Task.java | 30 +- .../org/elasticsearch/tasks/TaskManager.java | 7 +- .../transport/TransportService.java | 6 +- .../tasks/RecordingTaskManagerListener.java | 81 ++++++ .../admin/cluster/node/tasks/TasksIT.java | 258 +++++++++++++++++- .../node/tasks/TransportTasksActionTests.java | 59 +++- .../TransportBroadcastByNodeActionTests.java | 10 +- .../BroadcastReplicationTests.java | 3 +- .../TransportReplicationActionTests.java | 16 +- .../messy/tests/IndicesRequestTests.java | 10 +- .../test/cluster/TestClusterService.java | 6 + .../test/tasks/MockTaskManager.java | 82 ++++++ .../test/tasks/MockTaskManagerListener.java | 31 +++ .../test/transport/MockTransportService.java | 11 + 34 files changed, 763 insertions(+), 156 deletions(-) create mode 100644 core/src/main/java/org/elasticsearch/action/support/ChildTaskActionRequest.java delete mode 100644 core/src/main/java/org/elasticsearch/tasks/ChildTask.java create mode 100644 core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java create mode 100644 test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManager.java create mode 100644 test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManagerListener.java diff --git a/core/src/main/java/org/elasticsearch/action/admin/indices/upgrade/post/TransportUpgradeAction.java b/core/src/main/java/org/elasticsearch/action/admin/indices/upgrade/post/TransportUpgradeAction.java index 6e172f3e22f..f3cf2da9fdd 100644 --- a/core/src/main/java/org/elasticsearch/action/admin/indices/upgrade/post/TransportUpgradeAction.java +++ b/core/src/main/java/org/elasticsearch/action/admin/indices/upgrade/post/TransportUpgradeAction.java @@ -41,6 +41,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -179,7 +180,7 @@ public class TransportUpgradeAction extends TransportBroadcastByNodeAction listener) { + protected void doExecute(Task task, UpgradeRequest request, final ActionListener listener) { ActionListener settingsUpdateListener = new ActionListener() { @Override public void onResponse(UpgradeResponse upgradeResponse) { @@ -199,7 +200,7 @@ public class TransportUpgradeAction extends TransportBroadcastByNodeAction listener) { diff --git a/core/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java b/core/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java index 14086605f41..72cbe37c919 100644 --- a/core/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java +++ b/core/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java @@ -52,6 +52,7 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.internal.DefaultSearchContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchLocalRequest; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -89,9 +90,9 @@ public class TransportValidateQueryAction extends TransportBroadcastAction listener) { + protected void doExecute(Task task, ValidateQueryRequest request, ActionListener listener) { request.nowInMillis = System.currentTimeMillis(); - super.doExecute(request, listener); + super.doExecute(task, request, listener); } @Override diff --git a/core/src/main/java/org/elasticsearch/action/delete/TransportDeleteAction.java b/core/src/main/java/org/elasticsearch/action/delete/TransportDeleteAction.java index c235144c662..3a0e7aeec21 100644 --- a/core/src/main/java/org/elasticsearch/action/delete/TransportDeleteAction.java +++ b/core/src/main/java/org/elasticsearch/action/delete/TransportDeleteAction.java @@ -44,6 +44,7 @@ import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexAlreadyExistsException; import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -69,27 +70,27 @@ public class TransportDeleteAction extends TransportReplicationAction listener) { + protected void doExecute(Task task, final DeleteRequest request, final ActionListener listener) { ClusterState state = clusterService.state(); if (autoCreateIndex.shouldAutoCreate(request.index(), state)) { - createIndexAction.execute(new CreateIndexRequest().index(request.index()).cause("auto(delete api)").masterNodeTimeout(request.timeout()), new ActionListener() { + createIndexAction.execute(task, new CreateIndexRequest().index(request.index()).cause("auto(delete api)").masterNodeTimeout(request.timeout()), new ActionListener() { @Override public void onResponse(CreateIndexResponse result) { - innerExecute(request, listener); + innerExecute(task, request, listener); } @Override public void onFailure(Throwable e) { if (ExceptionsHelper.unwrapCause(e) instanceof IndexAlreadyExistsException) { // we have the index, do it - innerExecute(request, listener); + innerExecute(task, request, listener); } else { listener.onFailure(e); } } }); } else { - innerExecute(request, listener); + innerExecute(task, request, listener); } } @@ -114,8 +115,8 @@ public class TransportDeleteAction extends TransportReplicationAction listener) { - super.doExecute(request, listener); + private void innerExecute(Task task, final DeleteRequest request, final ActionListener listener) { + super.doExecute(task, request, listener); } @Override diff --git a/core/src/main/java/org/elasticsearch/action/index/TransportIndexAction.java b/core/src/main/java/org/elasticsearch/action/index/TransportIndexAction.java index ae901e8575d..33bf3547d0b 100644 --- a/core/src/main/java/org/elasticsearch/action/index/TransportIndexAction.java +++ b/core/src/main/java/org/elasticsearch/action/index/TransportIndexAction.java @@ -48,6 +48,7 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.indices.IndexAlreadyExistsException; import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -84,7 +85,7 @@ public class TransportIndexAction extends TransportReplicationAction listener) { + protected void doExecute(Task task, final IndexRequest request, final ActionListener listener) { // if we don't have a master, we don't have metadata, that's fine, let it find a master using create index API ClusterState state = clusterService.state(); if (autoCreateIndex.shouldAutoCreate(request.index(), state)) { @@ -93,10 +94,10 @@ public class TransportIndexAction extends TransportReplicationAction() { + createIndexAction.execute(task, createIndexRequest, new ActionListener() { @Override public void onResponse(CreateIndexResponse result) { - innerExecute(request, listener); + innerExecute(task, request, listener); } @Override @@ -104,7 +105,7 @@ public class TransportIndexAction extends TransportReplicationAction listener) { - super.doExecute(request, listener); + private void innerExecute(Task task, final IndexRequest request, final ActionListener listener) { + super.doExecute(task, request, listener); } @Override diff --git a/core/src/main/java/org/elasticsearch/action/percolate/TransportPercolateAction.java b/core/src/main/java/org/elasticsearch/action/percolate/TransportPercolateAction.java index bba024068c3..0edce177be7 100644 --- a/core/src/main/java/org/elasticsearch/action/percolate/TransportPercolateAction.java +++ b/core/src/main/java/org/elasticsearch/action/percolate/TransportPercolateAction.java @@ -41,6 +41,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.percolator.PercolateException; import org.elasticsearch.percolator.PercolatorService; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -70,7 +71,7 @@ public class TransportPercolateAction extends TransportBroadcastAction listener) { + protected void doExecute(Task task, final PercolateRequest request, final ActionListener listener) { request.startTime = System.currentTimeMillis(); if (request.getRequest() != null) { //create a new get request to make sure it has the same headers and context as the original percolate request @@ -84,7 +85,7 @@ public class TransportPercolateAction extends TransportBroadcastAction> extends ActionRequest { + + private String parentTaskNode; + + private long parentTaskId; + + protected ChildTaskActionRequest() { + + } + + public void setParentTask(String parentTaskNode, long parentTaskId) { + this.parentTaskNode = parentTaskNode; + this.parentTaskId = parentTaskId; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + parentTaskNode = in.readOptionalString(); + parentTaskId = in.readLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(parentTaskNode); + out.writeLong(parentTaskId); + } + + @Override + public Task createTask(long id, String type, String action) { + return new Task(id, type, action, this::getDescription, parentTaskNode, parentTaskId); + } + +} diff --git a/core/src/main/java/org/elasticsearch/action/support/ChildTaskRequest.java b/core/src/main/java/org/elasticsearch/action/support/ChildTaskRequest.java index 0483ec66e44..59ebb476703 100644 --- a/core/src/main/java/org/elasticsearch/action/support/ChildTaskRequest.java +++ b/core/src/main/java/org/elasticsearch/action/support/ChildTaskRequest.java @@ -19,10 +19,8 @@ package org.elasticsearch.action.support; -import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.tasks.ChildTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportRequest; @@ -61,6 +59,6 @@ public class ChildTaskRequest extends TransportRequest { @Override public Task createTask(long id, String type, String action) { - return new ChildTask(id, type, action, this::getDescription, parentTaskNode, parentTaskId); + return new Task(id, type, action, this::getDescription, parentTaskNode, parentTaskId); } } diff --git a/core/src/main/java/org/elasticsearch/action/support/HandledTransportAction.java b/core/src/main/java/org/elasticsearch/action/support/HandledTransportAction.java index a439117fef6..337d98fce68 100644 --- a/core/src/main/java/org/elasticsearch/action/support/HandledTransportAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/HandledTransportAction.java @@ -44,13 +44,14 @@ public abstract class HandledTransportAction { @Override - public final void messageReceived(final Request request, final TransportChannel channel, Task task) throws Exception { - messageReceived(request, channel); + public final void messageReceived(Request request, TransportChannel channel) throws Exception { + throw new UnsupportedOperationException("the task parameter is required for this operation"); } @Override - public final void messageReceived(Request request, TransportChannel channel) throws Exception { - execute(request, new ActionListener() { + public final void messageReceived(final Request request, final TransportChannel channel, Task task) throws Exception { + // We already got the task created on the netty layer - no need to create it again on the transport layer + execute(task, request, new ActionListener() { @Override public void onResponse(Response response) { try { diff --git a/core/src/main/java/org/elasticsearch/action/support/TransportAction.java b/core/src/main/java/org/elasticsearch/action/support/TransportAction.java index 584ff14e756..ecc81227c44 100644 --- a/core/src/main/java/org/elasticsearch/action/support/TransportAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/TransportAction.java @@ -66,6 +66,11 @@ public abstract class TransportAction, Re return future; } + /** + * Use this method when the transport action call should result in creation of a new task associated with the call. + * + * This is a typical behavior. + */ public final Task execute(Request request, ActionListener listener) { Task task = taskManager.register("transport", actionName, request); if (task == null) { @@ -88,7 +93,10 @@ public abstract class TransportAction, Re return task; } - private final void execute(Task task, Request request, ActionListener listener) { + /** + * Use this method when the transport action should continue to run in the context of the current task + */ + public final void execute(Task task, Request request, ActionListener listener) { ActionRequestValidationException validationException = request.validate(); if (validationException != null) { diff --git a/core/src/main/java/org/elasticsearch/action/support/broadcast/BroadcastShardRequest.java b/core/src/main/java/org/elasticsearch/action/support/broadcast/BroadcastShardRequest.java index 921724e6572..76cb04f71f4 100644 --- a/core/src/main/java/org/elasticsearch/action/support/broadcast/BroadcastShardRequest.java +++ b/core/src/main/java/org/elasticsearch/action/support/broadcast/BroadcastShardRequest.java @@ -21,18 +21,18 @@ package org.elasticsearch.action.support.broadcast; import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.ChildTaskRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.transport.TransportRequest; import java.io.IOException; /** * */ -public abstract class BroadcastShardRequest extends TransportRequest implements IndicesRequest { +public abstract class BroadcastShardRequest extends ChildTaskRequest implements IndicesRequest { private ShardId shardId; diff --git a/core/src/main/java/org/elasticsearch/action/support/broadcast/TransportBroadcastAction.java b/core/src/main/java/org/elasticsearch/action/support/broadcast/TransportBroadcastAction.java index be851cfa7e2..c36f4cd1a0f 100644 --- a/core/src/main/java/org/elasticsearch/action/support/broadcast/TransportBroadcastAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/broadcast/TransportBroadcastAction.java @@ -35,6 +35,7 @@ import org.elasticsearch.cluster.routing.ShardIterator; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.BaseTransportResponseHandler; import org.elasticsearch.transport.TransportChannel; @@ -69,8 +70,13 @@ public abstract class TransportBroadcastAction listener) { - new AsyncBroadcastAction(request, listener).start(); + protected void doExecute(Task task, Request request, ActionListener listener) { + new AsyncBroadcastAction(task, request, listener).start(); + } + + @Override + protected final void doExecute(Request request, ActionListener listener) { + throw new UnsupportedOperationException("the task parameter is required for this operation"); } protected abstract Response newResponse(Request request, AtomicReferenceArray shardsResponses, ClusterState clusterState); @@ -93,6 +99,7 @@ public abstract class TransportBroadcastAction listener; private final ClusterState clusterState; @@ -102,7 +109,8 @@ public abstract class TransportBroadcastAction listener) { + protected AsyncBroadcastAction(Task task, Request request, ActionListener listener) { + this.task = task; this.request = request; this.listener = listener; @@ -158,6 +166,7 @@ public abstract class TransportBroadcastAction listener) { - new AsyncAction(request, listener).start(); + protected final void doExecute(Request request, ActionListener listener) { + throw new UnsupportedOperationException("the task parameter is required for this operation"); + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + new AsyncAction(task, request, listener).start(); } protected class AsyncAction { + private final Task task; private final Request request; private final ActionListener listener; private final ClusterState clusterState; @@ -220,7 +228,8 @@ public abstract class TransportBroadcastByNodeAction unavailableShardExceptions = new ArrayList<>(); - protected AsyncAction(Request request, ActionListener listener) { + protected AsyncAction(Task task, Request request, ActionListener listener) { + this.task = task; this.request = request; this.listener = listener; @@ -290,6 +299,9 @@ public abstract class TransportBroadcastByNodeAction shards, final int nodeIndex) { try { NodeRequest nodeRequest = new NodeRequest(node.getId(), request, shards); + if (task != null) { + nodeRequest.setParentTask(clusterService.localNode().id(), task.getId()); + } transportService.sendRequest(node, transportNodeBroadcastAction, nodeRequest, new BaseTransportResponseHandler() { @Override public NodeResponse newInstance() { @@ -422,7 +434,7 @@ public abstract class TransportBroadcastByNodeAction shards; diff --git a/core/src/main/java/org/elasticsearch/action/support/master/MasterNodeRequest.java b/core/src/main/java/org/elasticsearch/action/support/master/MasterNodeRequest.java index a964a44a140..93d34e09ac6 100644 --- a/core/src/main/java/org/elasticsearch/action/support/master/MasterNodeRequest.java +++ b/core/src/main/java/org/elasticsearch/action/support/master/MasterNodeRequest.java @@ -20,6 +20,7 @@ package org.elasticsearch.action.support.master; import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.support.ChildTaskActionRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.TimeValue; @@ -29,7 +30,7 @@ import java.io.IOException; /** * A based request for master based operation. */ -public abstract class MasterNodeRequest> extends ActionRequest { +public abstract class MasterNodeRequest> extends ChildTaskActionRequest { public static final TimeValue DEFAULT_MASTER_NODE_TIMEOUT = TimeValue.timeValueSeconds(30); diff --git a/core/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java b/core/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java index 087b3891a8b..3b8c751f934 100644 --- a/core/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java @@ -113,6 +113,9 @@ public abstract class TransportMasterNodeAction listener) { this.task = task; this.request = request; + if (task != null) { + request.setParentTask(clusterService.localNode().getId(), task.getId()); + } // TODO do we really need to wrap it in a listener? the handlers should be cheap if ((listener instanceof ThreadedActionListener) == false) { listener = new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.LISTENER, listener); diff --git a/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationRequest.java b/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationRequest.java index 1f79d99981f..ed23017410e 100644 --- a/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationRequest.java +++ b/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationRequest.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.WriteConsistencyLevel; +import org.elasticsearch.action.support.ChildTaskActionRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; @@ -38,7 +39,7 @@ import static org.elasticsearch.action.ValidateActions.addValidationError; /** * */ -public abstract class ReplicationRequest> extends ActionRequest implements IndicesRequest { +public abstract class ReplicationRequest> extends ChildTaskActionRequest implements IndicesRequest { public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(1, TimeUnit.MINUTES); diff --git a/core/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/core/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index 3daafce50b7..fd649f046e8 100644 --- a/core/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -40,6 +40,7 @@ import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -67,8 +68,14 @@ public abstract class TransportBroadcastReplicationAction listener) { + protected final void doExecute(final Request request, final ActionListener listener) { + throw new UnsupportedOperationException("the task parameter is required for this operation"); + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { final ClusterState clusterState = clusterService.state(); List shards = shards(request, clusterState); final CopyOnWriteArrayList shardsResponses = new CopyOnWriteArrayList(); @@ -107,12 +114,14 @@ public abstract class TransportBroadcastReplicationAction shardActionListener) { - replicatedBroadcastShardAction.execute(newShardRequest(request, shardId), shardActionListener); + protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { + ShardRequest shardRequest = newShardRequest(request, shardId); + shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + replicatedBroadcastShardAction.execute(shardRequest, shardActionListener); } /** diff --git a/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java index 58b73b5e672..a5977a42146 100644 --- a/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java @@ -61,6 +61,7 @@ import org.elasticsearch.index.translog.Translog; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.BaseTransportResponseHandler; import org.elasticsearch.transport.ConnectTransportException; @@ -134,8 +135,13 @@ public abstract class TransportReplicationAction listener) { - new ReroutePhase(request, listener).run(); + protected final void doExecute(Request request, ActionListener listener) { + throw new UnsupportedOperationException("the task parameter is required for this operation"); + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + new ReroutePhase(task, request, listener).run(); } protected abstract Response newResponseInstance(); @@ -244,8 +250,8 @@ public abstract class TransportReplicationAction { @Override - public void messageReceived(final Request request, final TransportChannel channel) throws Exception { - execute(request, new ActionListener() { + public void messageReceived(final Request request, final TransportChannel channel, Task task) throws Exception { + execute(task, request, new ActionListener() { @Override public void onResponse(Response result) { try { @@ -265,6 +271,11 @@ public abstract class TransportReplicationAction { @@ -407,8 +418,11 @@ public abstract class TransportReplicationAction listener) { + ReroutePhase(Task task, Request request, ActionListener listener) { this.request = request; + if (task != null) { + this.request.setParentTask(clusterService.localNode().getId(), task.getId()); + } this.listener = listener; this.observer = new ClusterStateObserver(clusterService, request.timeout(), logger, threadPool.getThreadContext()); } diff --git a/core/src/main/java/org/elasticsearch/action/support/tasks/BaseTasksRequest.java b/core/src/main/java/org/elasticsearch/action/support/tasks/BaseTasksRequest.java index 2257eaf71b1..5bb18801107 100644 --- a/core/src/main/java/org/elasticsearch/action/support/tasks/BaseTasksRequest.java +++ b/core/src/main/java/org/elasticsearch/action/support/tasks/BaseTasksRequest.java @@ -26,7 +26,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.tasks.ChildTask; import org.elasticsearch.tasks.Task; import java.io.IOException; @@ -164,20 +163,13 @@ public class BaseTasksRequest> extends if (actions() != null && actions().length > 0 && Regex.simpleMatch(actions(), task.getAction()) == false) { return false; } - if (parentNode() != null || parentTaskId() != BaseTasksRequest.ALL_TASKS) { - if (task instanceof ChildTask) { - if (parentNode() != null) { - if (parentNode().equals(((ChildTask) task).getParentNode()) == false) { - return false; - } - } - if (parentTaskId() != BaseTasksRequest.ALL_TASKS) { - if (parentTaskId() != ((ChildTask) task).getParentId()) { - return false; - } - } - } else { - // This is not a child task and we need to match parent node or id + if (parentNode() != null) { + if (parentNode().equals(task.getParentNode()) == false) { + return false; + } + } + if (parentTaskId() != BaseTasksRequest.ALL_TASKS) { + if (parentTaskId() != task.getParentId()) { return false; } } diff --git a/core/src/main/java/org/elasticsearch/action/termvectors/dfs/TransportDfsOnlyAction.java b/core/src/main/java/org/elasticsearch/action/termvectors/dfs/TransportDfsOnlyAction.java index 6970f1e7762..647e3cc7546 100644 --- a/core/src/main/java/org/elasticsearch/action/termvectors/dfs/TransportDfsOnlyAction.java +++ b/core/src/main/java/org/elasticsearch/action/termvectors/dfs/TransportDfsOnlyAction.java @@ -39,6 +39,7 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.controller.SearchPhaseController; import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.dfs.DfsSearchResult; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -69,9 +70,9 @@ public class TransportDfsOnlyAction extends TransportBroadcastAction listener) { + protected void doExecute(Task task, DfsOnlyRequest request, ActionListener listener) { request.nowInMillis = System.currentTimeMillis(); - super.doExecute(request, listener); + super.doExecute(task, request, listener); } @Override diff --git a/core/src/main/java/org/elasticsearch/tasks/ChildTask.java b/core/src/main/java/org/elasticsearch/tasks/ChildTask.java deleted file mode 100644 index 14d49baf398..00000000000 --- a/core/src/main/java/org/elasticsearch/tasks/ChildTask.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.elasticsearch.tasks; - -import org.elasticsearch.action.admin.cluster.node.tasks.list.TaskInfo; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.inject.Provider; - -/** - * Child task - */ -public class ChildTask extends Task { - - private final String parentNode; - - private final long parentId; - - public ChildTask(long id, String type, String action, Provider description, String parentNode, long parentId) { - super(id, type, action, description); - this.parentNode = parentNode; - this.parentId = parentId; - } - - /** - * Returns parent node of the task or null if task doesn't have any parent tasks - */ - public String getParentNode() { - return parentNode; - } - - /** - * Returns id of the parent task or -1L if task doesn't have any parent tasks - */ - public long getParentId() { - return parentId; - } - - public TaskInfo taskInfo(DiscoveryNode node, boolean detailed) { - return new TaskInfo(node, getId(), getType(), getAction(), detailed ? getDescription() : null, parentNode, parentId); - } -} diff --git a/core/src/main/java/org/elasticsearch/tasks/Task.java b/core/src/main/java/org/elasticsearch/tasks/Task.java index 9e925b09d1a..9e02bc7c5e9 100644 --- a/core/src/main/java/org/elasticsearch/tasks/Task.java +++ b/core/src/main/java/org/elasticsearch/tasks/Task.java @@ -29,6 +29,8 @@ import org.elasticsearch.common.inject.Provider; */ public class Task { + public static final long NO_PARENT_ID = 0; + private final long id; private final String type; @@ -37,15 +39,27 @@ public class Task { private final Provider description; + private final String parentNode; + + private final long parentId; + + public Task(long id, String type, String action, Provider description) { + this(id, type, action, description, null, NO_PARENT_ID); + } + + public Task(long id, String type, String action, Provider description, String parentNode, long parentId) { this.id = id; this.type = type; this.action = action; this.description = description; + this.parentNode = parentNode; + this.parentId = parentId; } + public TaskInfo taskInfo(DiscoveryNode node, boolean detailed) { - return new TaskInfo(node, id, type, action, detailed ? getDescription() : null); + return new TaskInfo(node, getId(), getType(), getAction(), detailed ? getDescription() : null, parentNode, parentId); } /** @@ -76,4 +90,18 @@ public class Task { return description.get(); } + /** + * Returns the parent node of the task or null if the task doesn't have any parent tasks + */ + public String getParentNode() { + return parentNode; + } + + /** + * Returns id of the parent task or NO_PARENT_ID if the task doesn't have any parent tasks + */ + public long getParentId() { + return parentId; + } + } diff --git a/core/src/main/java/org/elasticsearch/tasks/TaskManager.java b/core/src/main/java/org/elasticsearch/tasks/TaskManager.java index 68e2dcbe9a5..ef05f911908 100644 --- a/core/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/core/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -25,9 +25,11 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.ConcurrentMapLong; import org.elasticsearch.transport.TransportRequest; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicLong; /** @@ -61,9 +63,9 @@ public class TaskManager extends AbstractComponent { /** * Unregister the task */ - public void unregister(Task task) { + public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); - tasks.remove(task.getId()); + return tasks.remove(task.getId()); } /** @@ -72,5 +74,4 @@ public class TaskManager extends AbstractComponent { public Map getTasks() { return Collections.unmodifiableMap(new HashMap<>(tasks)); } - } diff --git a/core/src/main/java/org/elasticsearch/transport/TransportService.java b/core/src/main/java/org/elasticsearch/transport/TransportService.java index 8cff05a4d6a..a6a1cab4f05 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportService.java @@ -117,7 +117,7 @@ public class TransportService extends AbstractLifecycleComponent> events = new ArrayList<>(); + + public RecordingTaskManagerListener(DiscoveryNode localNode, String... actionMasks) { + this.actionMasks = actionMasks; + this.localNode = localNode; + } + + @Override + public synchronized void onTaskRegistered(Task task) { + if (Regex.simpleMatch(actionMasks, task.getAction())) { + events.add(new Tuple<>(true, task.taskInfo(localNode, true))); + } + } + + @Override + public synchronized void onTaskUnregistered(Task task) { + if (Regex.simpleMatch(actionMasks, task.getAction())) { + events.add(new Tuple<>(false, task.taskInfo(localNode, true))); + } + } + + public synchronized List> getEvents() { + return Collections.unmodifiableList(new ArrayList<>(events)); + } + + public synchronized List getRegistrationEvents() { + List events = this.events.stream().filter(Tuple::v1).map(Tuple::v2).collect(Collectors.toList()); + return Collections.unmodifiableList(events); + } + + public synchronized List getUnregistrationEvents() { + List events = this.events.stream().filter(event -> event.v1() == false).map(Tuple::v2).collect(Collectors.toList()); + return Collections.unmodifiableList(events); + } + + public synchronized void reset() { + events.clear(); + } + +} diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java index 4228c9fa699..fbb93202fcf 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java @@ -18,21 +18,277 @@ */ package org.elasticsearch.action.admin.cluster.node.tasks; +import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.list.TaskInfo; +import org.elasticsearch.action.admin.indices.refresh.RefreshAction; +import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeAction; +import org.elasticsearch.action.admin.indices.validate.query.ValidateQueryAction; +import org.elasticsearch.action.percolate.PercolateAction; +import org.elasticsearch.cluster.ClusterService; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.tasks.MockTaskManager; +import org.elasticsearch.test.transport.MockTransportService; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; /** * Integration tests for task management API + *

+ * We need at least 2 nodes so we have a master node a non-master node */ -@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE) +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, minNumDataNodes = 2) public class TasksIT extends ESIntegTestCase { + private Map, RecordingTaskManagerListener> listeners = new HashMap<>(); + + @Override + protected Collection> nodePlugins() { + return pluginList(MockTransportService.TestPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(MockTaskManager.USE_MOCK_TASK_MANAGER, true) + .build(); + } + public void testTaskCounts() { // Run only on data nodes ListTasksResponse response = client().admin().cluster().prepareListTasks("data:true").setActions(ListTasksAction.NAME + "[n]").get(); assertThat(response.getTasks().size(), greaterThanOrEqualTo(cluster().numDataNodes())); } + + public void testMasterNodeOperationTasks() { + registerTaskManageListeners(ClusterHealthAction.NAME); + + // First run the health on the master node - should produce only one task on the master node + internalCluster().masterClient().admin().cluster().prepareHealth().get(); + assertEquals(1, numberOfEvents(ClusterHealthAction.NAME, Tuple::v1)); // counting only registration events + assertEquals(1, numberOfEvents(ClusterHealthAction.NAME, event -> event.v1() == false)); // counting only unregistration events + + resetTaskManageListeners(ClusterHealthAction.NAME); + + // Now run the health on a non-master node - should produce one task on master and one task on another node + internalCluster().nonMasterClient().admin().cluster().prepareHealth().get(); + assertEquals(2, numberOfEvents(ClusterHealthAction.NAME, Tuple::v1)); // counting only registration events + assertEquals(2, numberOfEvents(ClusterHealthAction.NAME, event -> event.v1() == false)); // counting only unregistration events + List tasks = findEvents(ClusterHealthAction.NAME, Tuple::v1); + + // Verify that one of these tasks is a parent of another task + if (tasks.get(0).getParentNode() == null) { + assertParentTask(Collections.singletonList(tasks.get(1)), tasks.get(0)); + } else { + assertParentTask(Collections.singletonList(tasks.get(0)), tasks.get(1)); + } + } + + public void testTransportReplicationAllShardsTasks() { + registerTaskManageListeners(PercolateAction.NAME); // main task + registerTaskManageListeners(PercolateAction.NAME + "[s]"); // shard level tasks + createIndex("test"); + ensureGreen("test"); // Make sure all shards are allocated + client().preparePercolate().setIndices("test").setDocumentType("foo").setSource("{}").get(); + + // the percolate operation should produce one main task + NumShards numberOfShards = getNumShards("test"); + assertEquals(1, numberOfEvents(PercolateAction.NAME, Tuple::v1)); + // and then one operation per shard + assertEquals(numberOfShards.totalNumShards, numberOfEvents(PercolateAction.NAME + "[s]", Tuple::v1)); + + // the shard level tasks should have the main task as a parent + assertParentTask(findEvents(PercolateAction.NAME + "[s]", Tuple::v1), findEvents(PercolateAction.NAME, Tuple::v1).get(0)); + } + + public void testTransportBroadcastByNodeTasks() { + registerTaskManageListeners(UpgradeAction.NAME); // main task + registerTaskManageListeners(UpgradeAction.NAME + "[n]"); // node level tasks + createIndex("test"); + ensureGreen("test"); // Make sure all shards are allocated + client().admin().indices().prepareUpgrade("test").get(); + + // the percolate operation should produce one main task + assertEquals(1, numberOfEvents(UpgradeAction.NAME, Tuple::v1)); + // and then one operation per each node where shards are located + assertEquals(internalCluster().nodesInclude("test").size(), numberOfEvents(UpgradeAction.NAME + "[n]", Tuple::v1)); + + // all node level tasks should have the main task as a parent + assertParentTask(findEvents(UpgradeAction.NAME + "[n]", Tuple::v1), findEvents(UpgradeAction.NAME, Tuple::v1).get(0)); + } + + public void testTransportReplicationSingleShardTasks() { + registerTaskManageListeners(ValidateQueryAction.NAME); // main task + registerTaskManageListeners(ValidateQueryAction.NAME + "[s]"); // shard level tasks + createIndex("test"); + ensureGreen("test"); // Make sure all shards are allocated + client().admin().indices().prepareValidateQuery("test").get(); + + // the validate operation should produce one main task + assertEquals(1, numberOfEvents(ValidateQueryAction.NAME, Tuple::v1)); + // and then one operation + assertEquals(1, numberOfEvents(ValidateQueryAction.NAME + "[s]", Tuple::v1)); + // the shard level operation should have the main task as its parent + assertParentTask(findEvents(ValidateQueryAction.NAME + "[s]", Tuple::v1), findEvents(ValidateQueryAction.NAME, Tuple::v1).get(0)); + } + + + public void testTransportBroadcastReplicationTasks() { + registerTaskManageListeners(RefreshAction.NAME); // main task + registerTaskManageListeners(RefreshAction.NAME + "[s]"); // shard level tasks + registerTaskManageListeners(RefreshAction.NAME + "[s][*]"); // primary and replica shard tasks + createIndex("test"); + ensureGreen("test"); // Make sure all shards are allocated + client().admin().indices().prepareRefresh("test").get(); + + // the refresh operation should produce one main task + NumShards numberOfShards = getNumShards("test"); + + logger.debug("number of shards, total: [{}], primaries: [{}] ", numberOfShards.totalNumShards, numberOfShards.numPrimaries); + logger.debug("main events {}", numberOfEvents(RefreshAction.NAME, Tuple::v1)); + logger.debug("main event node {}", findEvents(RefreshAction.NAME, Tuple::v1).get(0).getNode().name()); + logger.debug("[s] events {}", numberOfEvents(RefreshAction.NAME + "[s]", Tuple::v1)); + logger.debug("[s][*] events {}", numberOfEvents(RefreshAction.NAME + "[s][*]", Tuple::v1)); + logger.debug("nodes with the index {}", internalCluster().nodesInclude("test")); + + assertEquals(1, numberOfEvents(RefreshAction.NAME, Tuple::v1)); + // Because it's broadcast replication action we will have as many [s] level requests + // as we have primary shards on the coordinating node plus we will have one task per primary outside of the + // coordinating node due to replication. + // If all primaries are on the coordinating node, the number of tasks should be equal to the number of primaries + // If all primaries are not on the coordinating node, the number of tasks should be equal to the number of primaries times 2 + assertThat(numberOfEvents(RefreshAction.NAME + "[s]", Tuple::v1), greaterThanOrEqualTo(numberOfShards.numPrimaries)); + assertThat(numberOfEvents(RefreshAction.NAME + "[s]", Tuple::v1), lessThanOrEqualTo(numberOfShards.numPrimaries * 2)); + + // Verify that all [s] events have the proper parent + // This is complicated because if the shard task runs on the same node it has main task as a parent + // but if it runs on non-coordinating node it would have another intermediate [s] task on the coordinating node as a parent + TaskInfo mainTask = findEvents(RefreshAction.NAME, Tuple::v1).get(0); + List sTasks = findEvents(RefreshAction.NAME + "[s]", Tuple::v1); + for (TaskInfo taskInfo : sTasks) { + if (mainTask.getNode().equals(taskInfo.getNode())) { + // This shard level task runs on the same node as a parent task - it should have the main task as a direct parent + assertParentTask(Collections.singletonList(taskInfo), mainTask); + } else { + String description = taskInfo.getDescription(); + // This shard level task runs on another node - it should have a corresponding shard level task on the node where main task is running + List sTasksOnRequestingNode = findEvents(RefreshAction.NAME + "[s]", + event -> event.v1() && mainTask.getNode().equals(event.v2().getNode()) && description.equals(event.v2().getDescription())); + // There should be only one parent task + assertEquals(1, sTasksOnRequestingNode.size()); + assertParentTask(Collections.singletonList(taskInfo), sTasksOnRequestingNode.get(0)); + } + } + + // we will have as many [s][p] and [s][r] tasks as we have primary and replica shards + assertEquals(numberOfShards.totalNumShards, numberOfEvents(RefreshAction.NAME + "[s][*]", Tuple::v1)); + + // we the [s][p] and [s][r] tasks should have a corresponding [s] task on the same node as a parent + List spEvents = findEvents(RefreshAction.NAME + "[s][*]", Tuple::v1); + for (TaskInfo taskInfo : spEvents) { + List sTask; + if (taskInfo.getAction().endsWith("[s][p]")) { + // A [s][p] level task should have a corresponding [s] level task on the same node + sTask = findEvents(RefreshAction.NAME + "[s]", + event -> event.v1() && taskInfo.getNode().equals(event.v2().getNode()) && taskInfo.getDescription().equals(event.v2().getDescription())); + } else { + // A [s][r] level task should have a corresponding [s] level task on the a different node (where primary is located) + sTask = findEvents(RefreshAction.NAME + "[s]", + event -> event.v1() && taskInfo.getParentNode().equals(event.v2().getNode().getId()) && taskInfo.getDescription().equals(event.v2().getDescription())); + } + // There should be only one parent task + assertEquals(1, sTask.size()); + assertParentTask(Collections.singletonList(taskInfo), sTask.get(0)); + } + } + + @Override + public void tearDown() throws Exception { + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + ((MockTaskManager)internalCluster().getInstance(ClusterService.class, entry.getKey().v1()).getTaskManager()).removeListener(entry.getValue()); + } + listeners.clear(); + super.tearDown(); + } + + /** + * Registers recording task event listeners with the given action mask on all nodes + */ + private void registerTaskManageListeners(String actionMasks) { + for (ClusterService clusterService : internalCluster().getInstances(ClusterService.class)) { + DiscoveryNode node = clusterService.localNode(); + RecordingTaskManagerListener listener = new RecordingTaskManagerListener(node, Strings.splitStringToArray(actionMasks, ',')); + ((MockTaskManager)clusterService.getTaskManager()).addListener(listener); + RecordingTaskManagerListener oldListener = listeners.put(new Tuple<>(node.name(), actionMasks), listener); + assertNull(oldListener); + } + } + + /** + * Resets all recording task event listeners with the given action mask on all nodes + */ + private void resetTaskManageListeners(String actionMasks) { + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { + entry.getValue().reset(); + } + } + } + + /** + * Returns the number of events that satisfy the criteria across all nodes + * + * @param actionMasks action masks to match + * @return number of events that satisfy the criteria + */ + private int numberOfEvents(String actionMasks, Function, Boolean> criteria) { + return findEvents(actionMasks, criteria).size(); + } + + /** + * Returns all events that satisfy the criteria across all nodes + * + * @param actionMasks action masks to match + * @return number of events that satisfy the criteria + */ + private List findEvents(String actionMasks, Function, Boolean> criteria) { + List events = new ArrayList<>(); + for (Map.Entry, RecordingTaskManagerListener> entry : listeners.entrySet()) { + if (actionMasks == null || entry.getKey().v2().equals(actionMasks)) { + for (Tuple taskEvent : entry.getValue().getEvents()) { + if (criteria.apply(taskEvent)) { + events.add(taskEvent.v2()); + } + } + } + } + return events; + } + + /** + * Asserts that all tasks in the tasks list have the same parentTask + */ + private void assertParentTask(List tasks, TaskInfo parentTask) { + for (TaskInfo task : tasks) { + assertNotNull(task.getParentNode()); + assertEquals(parentTask.getNode().getId(), task.getParentNode()); + assertEquals(parentTask.getId(), task.getParentId()); + } + } } diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index 3fbac003419..4e1c08261a3 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -48,10 +48,11 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.tasks.ChildTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.cluster.TestClusterService; +import org.elasticsearch.test.tasks.MockTaskManager; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.local.LocalTransport; @@ -95,12 +96,11 @@ public class TransportTasksActionTests extends ESTestCase { threadPool = null; } - @Before - public final void setupTestNodes() throws Exception { + public void setupTestNodes(Settings settings) { nodesCount = randomIntBetween(2, 10); testNodes = new TestNode[nodesCount]; for (int i = 0; i < testNodes.length; i++) { - testNodes[i] = new TestNode("node" + i, threadPool, Settings.EMPTY); + testNodes[i] = new TestNode("node" + i, threadPool, settings); } } @@ -113,11 +113,20 @@ public class TransportTasksActionTests extends ESTestCase { private static class TestNode implements Releasable { public TestNode(String name, ThreadPool threadPool, Settings settings) { - clusterService = new TestClusterService(threadPool); - transportService = new TransportService(Settings.EMPTY, - new LocalTransport(Settings.EMPTY, threadPool, Version.CURRENT, new NamedWriteableRegistry()), - threadPool); + transportService = new TransportService(settings, + new LocalTransport(settings, threadPool, Version.CURRENT, new NamedWriteableRegistry()), + threadPool){ + @Override + protected TaskManager createTaskManager() { + if (settings.getAsBoolean(MockTaskManager.USE_MOCK_TASK_MANAGER, false)) { + return new MockTaskManager(settings); + } else { + return super.createTaskManager(); + } + } + }; transportService.start(); + clusterService = new TestClusterService(threadPool, transportService); discoveryNode = new DiscoveryNode(name, transportService.boundAddress().publishAddress(), Version.CURRENT); transportListTasksAction = new TransportListTasksAction(settings, clusterName, threadPool, clusterService, transportService, new ActionFilters(Collections.emptySet()), new IndexNameExpressionResolver(settings)); @@ -150,6 +159,15 @@ public class TransportTasksActionTests extends ESTestCase { } } + public static RecordingTaskManagerListener[] setupListeners(TestNode[] nodes, String... actionMasks) { + RecordingTaskManagerListener[] listeners = new RecordingTaskManagerListener[nodes.length]; + for (int i = 0; i < nodes.length; i++) { + listeners[i] = new RecordingTaskManagerListener(nodes[i].discoveryNode, actionMasks); + ((MockTaskManager)(nodes[i].clusterService.getTaskManager())).addListener(listeners[i]); + } + return listeners; + } + public static class NodeRequest extends BaseNodeRequest { protected String requestName; private boolean enableTaskManager; @@ -180,7 +198,7 @@ public class TransportTasksActionTests extends ESTestCase { @Override public String getDescription() { - return "NodeRequest[" + requestName + ", " + enableTaskManager + "]"; + return "NodeRequest[" + requestName + ", " + enableTaskManager + "]"; } @Override @@ -464,6 +482,7 @@ public class TransportTasksActionTests extends ESTestCase { } public void testRunningTasksCount() throws Exception { + setupTestNodes(Settings.EMPTY); connectNodes(testNodes); CountDownLatch checkLatch = new CountDownLatch(1); CountDownLatch responseLatch = new CountDownLatch(1); @@ -553,6 +572,7 @@ public class TransportTasksActionTests extends ESTestCase { } public void testFindChildTasks() throws Exception { + setupTestNodes(Settings.EMPTY); connectNodes(testNodes); CountDownLatch checkLatch = new CountDownLatch(1); ActionFuture future = startBlockingTestNodesAction(checkLatch); @@ -586,10 +606,11 @@ public class TransportTasksActionTests extends ESTestCase { } public void testTaskManagementOptOut() throws Exception { + setupTestNodes(Settings.EMPTY); connectNodes(testNodes); CountDownLatch checkLatch = new CountDownLatch(1); // Starting actions that disable task manager - ActionFuture future = startBlockingTestNodesAction(checkLatch, new NodesRequest("Test Request", false)); + ActionFuture future = startBlockingTestNodesAction(checkLatch, new NodesRequest("Test Request", false)); TestNode testNode = testNodes[randomIntBetween(0, testNodes.length - 1)]; @@ -606,6 +627,7 @@ public class TransportTasksActionTests extends ESTestCase { } public void testTasksDescriptions() throws Exception { + setupTestNodes(Settings.EMPTY); connectNodes(testNodes); CountDownLatch checkLatch = new CountDownLatch(1); ActionFuture future = startBlockingTestNodesAction(checkLatch); @@ -637,8 +659,11 @@ public class TransportTasksActionTests extends ESTestCase { } public void testFailedTasksCount() throws ExecutionException, InterruptedException, IOException { + Settings settings = Settings.builder().put(MockTaskManager.USE_MOCK_TASK_MANAGER, true).build(); + setupTestNodes(settings); connectNodes(testNodes); TestNodesAction[] actions = new TestNodesAction[nodesCount]; + RecordingTaskManagerListener[] listeners = setupListeners(testNodes, "testAction*"); for (int i = 0; i < testNodes.length; i++) { final int node = i; actions[i] = new TestNodesAction(Settings.EMPTY, "testAction", clusterName, threadPool, testNodes[i].clusterService, testNodes[i].transportService) { @@ -656,9 +681,21 @@ public class TransportTasksActionTests extends ESTestCase { NodesRequest request = new NodesRequest("Test Request"); NodesResponse responses = actions[0].execute(request).get(); assertEquals(nodesCount, responses.failureCount()); + + // Make sure that actions are still registered in the task manager on all nodes + // Twice on the coordinating node and once on all other nodes. + assertEquals(4, listeners[0].getEvents().size()); + assertEquals(2, listeners[0].getRegistrationEvents().size()); + assertEquals(2, listeners[0].getUnregistrationEvents().size()); + for (int i = 1; i < listeners.length; i++) { + assertEquals(2, listeners[i].getEvents().size()); + assertEquals(1, listeners[i].getRegistrationEvents().size()); + assertEquals(1, listeners[i].getUnregistrationEvents().size()); + } } public void testTaskLevelActionFailures() throws ExecutionException, InterruptedException, IOException { + setupTestNodes(Settings.EMPTY); connectNodes(testNodes); CountDownLatch checkLatch = new CountDownLatch(1); ActionFuture future = startBlockingTestNodesAction(checkLatch); @@ -672,7 +709,7 @@ public class TransportTasksActionTests extends ESTestCase { @Override protected TestTaskResponse taskOperation(TestTasksRequest request, Task task) { logger.info("Task action on node " + node); - if (failTaskOnNode == node && ((ChildTask) task).getParentNode() != null) { + if (failTaskOnNode == node && task.getParentNode() != null) { logger.info("Failing on node " + node); throw new RuntimeException("Task level failure"); } diff --git a/core/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java b/core/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java index 76307ccd806..a408ccc5bf9 100644 --- a/core/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java @@ -242,7 +242,7 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase { .addGlobalBlock(new ClusterBlock(1, "test-block", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL)); clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block)); try { - action.new AsyncAction(request, listener).start(); + action.new AsyncAction(null, request, listener).start(); fail("expected ClusterBlockException"); } catch (ClusterBlockException expected) { assertEquals("blocked by: [SERVICE_UNAVAILABLE/1/test-block];", expected.getMessage()); @@ -257,7 +257,7 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase { .addIndexBlock(TEST_INDEX, new ClusterBlock(1, "test-block", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL)); clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block)); try { - action.new AsyncAction(request, listener).start(); + action.new AsyncAction(null, request, listener).start(); fail("expected ClusterBlockException"); } catch (ClusterBlockException expected) { assertEquals("blocked by: [SERVICE_UNAVAILABLE/1/test-block];", expected.getMessage()); @@ -268,7 +268,7 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase { Request request = new Request(new String[]{TEST_INDEX}); PlainActionFuture listener = new PlainActionFuture<>(); - action.new AsyncAction(request, listener).start(); + action.new AsyncAction(null, request, listener).start(); Map> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear(); ShardsIterator shardIt = clusterService.state().routingTable().allShards(new String[]{TEST_INDEX}); @@ -302,7 +302,7 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase { clusterService.setState(ClusterState.builder(clusterService.state()).nodes(builder)); - action.new AsyncAction(request, listener).start(); + action.new AsyncAction(null, request, listener).start(); Map> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear(); @@ -389,7 +389,7 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase { clusterService.setState(ClusterState.builder(clusterService.state()).nodes(builder)); } - action.new AsyncAction(request, listener).start(); + action.new AsyncAction(null, request, listener).start(); Map> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear(); ShardsIterator shardIt = clusterService.state().getRoutingTable().allShards(new String[]{TEST_INDEX}); diff --git a/core/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/core/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index fccdd494af7..7b9fd91a567 100644 --- a/core/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -41,6 +41,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.cluster.TestClusterService; import org.elasticsearch.threadpool.ThreadPool; @@ -207,7 +208,7 @@ public class BroadcastReplicationTests extends ESTestCase { } @Override - protected void shardExecute(DummyBroadcastRequest request, ShardId shardId, ActionListener shardActionListener) { + protected void shardExecute(Task task, DummyBroadcastRequest request, ShardId shardId, ActionListener shardActionListener) { capturedShardRequests.add(new Tuple<>(shardId, shardActionListener)); } } diff --git a/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java b/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java index c868b0d036f..402a454649b 100644 --- a/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java @@ -140,7 +140,7 @@ public class TransportReplicationActionTests extends ESTestCase { ClusterBlocks.Builder block = ClusterBlocks.builder() .addGlobalBlock(new ClusterBlock(1, "non retryable", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL)); clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block)); - TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(request, listener); + TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener); reroutePhase.run(); assertListenerThrows("primary phase should fail operation", listener, ClusterBlockException.class); @@ -148,13 +148,13 @@ public class TransportReplicationActionTests extends ESTestCase { .addGlobalBlock(new ClusterBlock(1, "retryable", true, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL)); clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block)); listener = new PlainActionFuture<>(); - reroutePhase = action.new ReroutePhase(new Request().timeout("5ms"), listener); + reroutePhase = action.new ReroutePhase(null, new Request().timeout("5ms"), listener); reroutePhase.run(); assertListenerThrows("failed to timeout on retryable block", listener, ClusterBlockException.class); listener = new PlainActionFuture<>(); - reroutePhase = action.new ReroutePhase(new Request(), listener); + reroutePhase = action.new ReroutePhase(null, new Request(), listener); reroutePhase.run(); assertFalse("primary phase should wait on retryable block", listener.isDone()); @@ -180,13 +180,13 @@ public class TransportReplicationActionTests extends ESTestCase { Request request = new Request(shardId).timeout("1ms"); PlainActionFuture listener = new PlainActionFuture<>(); - TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(request, listener); + TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener); reroutePhase.run(); assertListenerThrows("unassigned primary didn't cause a timeout", listener, UnavailableShardsException.class); request = new Request(shardId); listener = new PlainActionFuture<>(); - reroutePhase = action.new ReroutePhase(request, listener); + reroutePhase = action.new ReroutePhase(null, request, listener); reroutePhase.run(); assertFalse("unassigned primary didn't cause a retry", listener.isDone()); @@ -211,12 +211,12 @@ public class TransportReplicationActionTests extends ESTestCase { logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint()); Request request = new Request(new ShardId("unknown_index", "_na_", 0)).timeout("1ms"); PlainActionFuture listener = new PlainActionFuture<>(); - TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(request, listener); + TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener); reroutePhase.run(); assertListenerThrows("must throw index not found exception", listener, IndexNotFoundException.class); request = new Request(new ShardId(index, "_na_", 10)).timeout("1ms"); listener = new PlainActionFuture<>(); - reroutePhase = action.new ReroutePhase(request, listener); + reroutePhase = action.new ReroutePhase(null, request, listener); reroutePhase.run(); assertListenerThrows("must throw shard not found exception", listener, ShardNotFoundException.class); } @@ -234,7 +234,7 @@ public class TransportReplicationActionTests extends ESTestCase { Request request = new Request(shardId); PlainActionFuture listener = new PlainActionFuture<>(); - TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(request, listener); + TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener); reroutePhase.run(); assertThat(request.shardId(), equalTo(shardId)); logger.info("--> primary is assigned to [{}], checking request forwarded", primaryNodeId); diff --git a/modules/lang-groovy/src/test/java/org/elasticsearch/messy/tests/IndicesRequestTests.java b/modules/lang-groovy/src/test/java/org/elasticsearch/messy/tests/IndicesRequestTests.java index 9a3a4632c6f..b5cd130e191 100644 --- a/modules/lang-groovy/src/test/java/org/elasticsearch/messy/tests/IndicesRequestTests.java +++ b/modules/lang-groovy/src/test/java/org/elasticsearch/messy/tests/IndicesRequestTests.java @@ -92,6 +92,7 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.script.Script; import org.elasticsearch.script.groovy.GroovyPlugin; import org.elasticsearch.search.action.SearchServiceTransportAction; +import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase.ClusterScope; @@ -821,7 +822,7 @@ public class IndicesRequestTests extends ESIntegTestCase { } @Override - public void messageReceived(T request, TransportChannel channel) throws Exception { + public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { synchronized (InterceptingTransportService.this) { if (actions.contains(action)) { List requestList = requests.get(action); @@ -834,7 +835,12 @@ public class IndicesRequestTests extends ESIntegTestCase { } } } - requestHandler.messageReceived(request, channel); + requestHandler.messageReceived(request, channel, task); + } + + @Override + public void messageReceived(T request, TransportChannel channel) throws Exception { + messageReceived(request, channel, null); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/cluster/TestClusterService.java b/test/framework/src/main/java/org/elasticsearch/test/cluster/TestClusterService.java index 92b5f9a584b..172c746e88e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/cluster/TestClusterService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/cluster/TestClusterService.java @@ -49,6 +49,7 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; import java.util.Arrays; import java.util.Iterator; @@ -77,6 +78,11 @@ public class TestClusterService implements ClusterService { taskManager = new TaskManager(Settings.EMPTY); } + public TestClusterService(ThreadPool threadPool, TransportService transportService) { + this(ClusterState.builder(new ClusterName("test")).build(), threadPool); + taskManager = transportService.getTaskManager(); + } + public TestClusterService(ClusterState state) { this(state, null); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManager.java b/test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManager.java new file mode 100644 index 00000000000..9b6bc72162c --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManager.java @@ -0,0 +1,82 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.test.tasks; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.transport.TransportRequest; + +import java.util.Collection; +import java.util.concurrent.CopyOnWriteArrayList; + +/** + * A mock task manager that allows adding listeners for events + */ +public class MockTaskManager extends TaskManager { + + public static final String USE_MOCK_TASK_MANAGER = "tests.mock.taskmanager.enabled"; + + private final Collection listeners = new CopyOnWriteArrayList<>(); + + public MockTaskManager(Settings settings) { + super(settings); + } + + @Override + public Task register(String type, String action, TransportRequest request) { + Task task = super.register(type, action, request); + if (task != null) { + for (MockTaskManagerListener listener : listeners) { + try { + listener.onTaskRegistered(task); + } catch (Throwable t) { + logger.warn("failed to notify task manager listener about unregistering the task with id {}", t, task.getId()); + } + } + } + return task; + } + + @Override + public Task unregister(Task task) { + Task removedTask = super.unregister(task); + if (removedTask != null) { + for (MockTaskManagerListener listener : listeners) { + try { + listener.onTaskUnregistered(task); + } catch (Throwable t) { + logger.warn("failed to notify task manager listener about unregistering the task with id {}", t, task.getId()); + } + } + } else { + logger.warn("trying to remove the same with id {} twice", task.getId()); + } + return removedTask; + } + + public void addListener(MockTaskManagerListener listener) { + listeners.add(listener); + } + + public void removeListener(MockTaskManagerListener listener) { + listeners.remove(listener); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManagerListener.java b/test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManagerListener.java new file mode 100644 index 00000000000..d10dd357999 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/tasks/MockTaskManagerListener.java @@ -0,0 +1,31 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.test.tasks; + +import org.elasticsearch.tasks.Task; + +/** + * Listener for task registration/unregistration + */ +public interface MockTaskManagerListener { + void onTaskRegistered(Task task); + + void onTaskUnregistered(Task task); +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java index 985c8a86838..84c981d60d5 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java @@ -34,6 +34,8 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.test.tasks.MockTaskManager; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.RequestHandlerRegistry; @@ -100,6 +102,15 @@ public class MockTransportService extends TransportService { return transportAddresses.toArray(new TransportAddress[transportAddresses.size()]); } + @Override + protected TaskManager createTaskManager() { + if (settings.getAsBoolean(MockTaskManager.USE_MOCK_TASK_MANAGER, false)) { + return new MockTaskManager(settings); + } else { + return super.createTaskManager(); + } + } + /** * Clears all the registered rules. */ From e3816d58aeb8ee750e0ada6ac97f379b7f3578b3 Mon Sep 17 00:00:00 2001 From: Ali Beyad Date: Thu, 28 Jan 2016 17:10:04 -0500 Subject: [PATCH 08/11] TribeIT.testOnConflictDrop test awaits fix until #16299 is resolved --- core/src/test/java/org/elasticsearch/tribe/TribeIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/java/org/elasticsearch/tribe/TribeIT.java b/core/src/test/java/org/elasticsearch/tribe/TribeIT.java index 260c6252efd..506321684ba 100644 --- a/core/src/test/java/org/elasticsearch/tribe/TribeIT.java +++ b/core/src/test/java/org/elasticsearch/tribe/TribeIT.java @@ -225,6 +225,7 @@ public class TribeIT extends ESIntegTestCase { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/16299") public void testOnConflictDrop() throws Exception { logger.info("create 2 indices, test1 on t1, and test2 on t2"); assertAcked(cluster().client().admin().indices().prepareCreate("conflict")); From 9ec1b11148fe2a347c31f64d6630a42fffbef763 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Thu, 28 Jan 2016 10:31:43 +0100 Subject: [PATCH 09/11] term vectors: The term vector APIs no longer modify the mappings if an unmapped field is found --- .../index/termvectors/TermVectorsService.java | 9 +--- .../action/termvectors/GetTermVectorsIT.java | 47 ------------------- docs/reference/migration/migrate_3_0.asciidoc | 6 +++ 3 files changed, 7 insertions(+), 55 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java b/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java index fbc18fd578a..c78d125ab46 100644 --- a/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java +++ b/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java @@ -71,12 +71,10 @@ import static org.elasticsearch.index.mapper.SourceToParse.source; public class TermVectorsService { - private final MappingUpdatedAction mappingUpdatedAction; private final TransportDfsOnlyAction dfsAction; @Inject - public TermVectorsService(MappingUpdatedAction mappingUpdatedAction, TransportDfsOnlyAction dfsAction) { - this.mappingUpdatedAction = mappingUpdatedAction; + public TermVectorsService(TransportDfsOnlyAction dfsAction) { this.dfsAction = dfsAction; } @@ -293,16 +291,11 @@ public class TermVectorsService { private ParsedDocument parseDocument(IndexShard indexShard, String index, String type, BytesReference doc) throws Throwable { MapperService mapperService = indexShard.mapperService(); - - // TODO: make parsing not dynamically create fields not in the original mapping DocumentMapperForType docMapper = mapperService.documentMapperWithAutoCreate(type); ParsedDocument parsedDocument = docMapper.getDocumentMapper().parse(source(doc).index(index).type(type).flyweight(true)); if (docMapper.getMapping() != null) { parsedDocument.addDynamicMappingsUpdate(docMapper.getMapping()); } - if (parsedDocument.dynamicMappingsUpdate() != null) { - mappingUpdatedAction.updateMappingOnMasterSynchronously(index, type, parsedDocument.dynamicMappingsUpdate()); - } return parsedDocument; } diff --git a/core/src/test/java/org/elasticsearch/action/termvectors/GetTermVectorsIT.java b/core/src/test/java/org/elasticsearch/action/termvectors/GetTermVectorsIT.java index 16fd9f4b718..45b045a7f4b 100644 --- a/core/src/test/java/org/elasticsearch/action/termvectors/GetTermVectorsIT.java +++ b/core/src/test/java/org/elasticsearch/action/termvectors/GetTermVectorsIT.java @@ -870,53 +870,6 @@ public class GetTermVectorsIT extends AbstractTermVectorsTestCase { checkBrownFoxTermVector(resp.getFields(), "field1", false); } - public void testArtificialNonExistingField() throws Exception { - // setup indices - Settings.Builder settings = settingsBuilder() - .put(indexSettings()) - .put("index.analysis.analyzer", "standard"); - assertAcked(prepareCreate("test") - .setSettings(settings) - .addMapping("type1", "field1", "type=string")); - ensureGreen(); - - // index just one doc - List indexBuilders = new ArrayList<>(); - indexBuilders.add(client().prepareIndex() - .setIndex("test") - .setType("type1") - .setId("1") - .setRouting("1") - .setSource("field1", "some text")); - indexRandom(true, indexBuilders); - - // request tvs from artificial document - XContentBuilder doc = jsonBuilder() - .startObject() - .field("field1", "the quick brown fox jumps over the lazy dog") - .field("non_existing", "the quick brown fox jumps over the lazy dog") - .endObject(); - - for (int i = 0; i < 2; i++) { - TermVectorsResponse resp = client().prepareTermVectors() - .setIndex("test") - .setType("type1") - .setDoc(doc) - .setRouting("" + i) - .setOffsets(true) - .setPositions(true) - .setFieldStatistics(true) - .setTermStatistics(true) - .get(); - assertThat(resp.isExists(), equalTo(true)); - checkBrownFoxTermVector(resp.getFields(), "field1", false); - // we should have created a mapping for this field - assertMappingOnMaster("test", "type1", "non_existing"); - // and return the generated term vectors - checkBrownFoxTermVector(resp.getFields(), "non_existing", false); - } - } - public void testPerFieldAnalyzer() throws IOException { int numFields = 25; diff --git a/docs/reference/migration/migrate_3_0.asciidoc b/docs/reference/migration/migrate_3_0.asciidoc index 067932ca21f..78f8ff40307 100644 --- a/docs/reference/migration/migrate_3_0.asciidoc +++ b/docs/reference/migration/migrate_3_0.asciidoc @@ -18,6 +18,7 @@ your application to Elasticsearch 3.0. * <> * <> * <> +* <> [[breaking_30_search_changes]] === Warmers @@ -707,3 +708,8 @@ Previously script mode settings (e.g., "script.inline: true", values `off`, `false`, `0`, and `no` for disabling a scripting mode. The variants `on`, `1`, and `yes ` for enabling and `off`, `0`, and `no` for disabling are no longer supported. + +[[breaking_30_term_vectors]] +=== Term vectors + +The term vectors APIs no longer persist unmapped fields in the mappings. From f5e89f724281dd8d20e4fcbebeb29e2885d095bc Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Thu, 28 Jan 2016 10:15:04 +0100 Subject: [PATCH 10/11] mappings: remove fly weight --- .../index/mapper/ParseContext.java | 13 +--------- .../index/mapper/SourceToParse.java | 11 --------- .../index/mapper/internal/IdFieldMapper.java | 2 +- .../mapper/internal/ParentFieldMapper.java | 4 +--- .../mapper/internal/SourceFieldMapper.java | 5 ++-- .../index/mapper/internal/TTLFieldMapper.java | 2 +- .../index/mapper/internal/UidFieldMapper.java | 2 +- .../percolator/PercolatorFieldMapper.java | 4 +--- .../index/termvectors/TermVectorsService.java | 2 +- .../percolator/PercolateDocumentParser.java | 24 +++++++++---------- .../index/mapper/size/SizeFieldMapper.java | 2 +- 11 files changed, 22 insertions(+), 49 deletions(-) diff --git a/core/src/main/java/org/elasticsearch/index/mapper/ParseContext.java b/core/src/main/java/org/elasticsearch/index/mapper/ParseContext.java index 3c12f51a7f7..938dd778b0e 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/ParseContext.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/ParseContext.java @@ -181,11 +181,6 @@ public abstract class ParseContext { this.in = in; } - @Override - public boolean flyweight() { - return in.flyweight(); - } - @Override public DocumentMapperParser docMapperParser() { return in.docMapperParser(); @@ -411,11 +406,6 @@ public abstract class ParseContext { this.dynamicMappingsUpdate = null; } - @Override - public boolean flyweight() { - return sourceToParse.flyweight(); - } - @Override public DocumentMapperParser docMapperParser() { return this.docMapperParser; @@ -580,8 +570,6 @@ public abstract class ParseContext { } } - public abstract boolean flyweight(); - public abstract DocumentMapperParser docMapperParser(); /** @@ -658,6 +646,7 @@ public abstract class ParseContext { public abstract SourceToParse sourceToParse(); + @Nullable public abstract BytesReference source(); // only should be used by SourceFieldMapper to update with a compressed source diff --git a/core/src/main/java/org/elasticsearch/index/mapper/SourceToParse.java b/core/src/main/java/org/elasticsearch/index/mapper/SourceToParse.java index f65072d489e..6094caa319a 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/SourceToParse.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/SourceToParse.java @@ -46,8 +46,6 @@ public class SourceToParse { private final XContentParser parser; - private boolean flyweight = false; - private String index; private String type; @@ -106,15 +104,6 @@ public class SourceToParse { return this; } - public SourceToParse flyweight(boolean flyweight) { - this.flyweight = flyweight; - return this; - } - - public boolean flyweight() { - return this.flyweight; - } - public String id() { return this.id; } diff --git a/core/src/main/java/org/elasticsearch/index/mapper/internal/IdFieldMapper.java b/core/src/main/java/org/elasticsearch/index/mapper/internal/IdFieldMapper.java index a586a7b5b94..1f26dd60841 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/internal/IdFieldMapper.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/internal/IdFieldMapper.java @@ -220,7 +220,7 @@ public class IdFieldMapper extends MetadataFieldMapper { @Override public void postParse(ParseContext context) throws IOException { - if (context.id() == null && !context.sourceToParse().flyweight()) { + if (context.id() == null) { throw new MapperParsingException("No id found while parsing the content source"); } // it either get built in the preParse phase, or get parsed... diff --git a/core/src/main/java/org/elasticsearch/index/mapper/internal/ParentFieldMapper.java b/core/src/main/java/org/elasticsearch/index/mapper/internal/ParentFieldMapper.java index 86009ffcccc..e7cd1b107ae 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/internal/ParentFieldMapper.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/internal/ParentFieldMapper.java @@ -228,9 +228,7 @@ public class ParentFieldMapper extends MetadataFieldMapper { @Override public void postParse(ParseContext context) throws IOException { - if (context.sourceToParse().flyweight() == false) { - parse(context); - } + parse(context); } @Override diff --git a/core/src/main/java/org/elasticsearch/index/mapper/internal/SourceFieldMapper.java b/core/src/main/java/org/elasticsearch/index/mapper/internal/SourceFieldMapper.java index 1925b2b2faa..519a38c4ff3 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/internal/SourceFieldMapper.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/internal/SourceFieldMapper.java @@ -251,10 +251,11 @@ public class SourceFieldMapper extends MetadataFieldMapper { if (!fieldType().stored()) { return; } - if (context.flyweight()) { + BytesReference source = context.source(); + // Percolate and tv APIs may not set the source and that is ok, because these APIs will not index any data + if (source == null) { return; } - BytesReference source = context.source(); boolean filtered = (includes != null && includes.length > 0) || (excludes != null && excludes.length > 0); if (filtered) { diff --git a/core/src/main/java/org/elasticsearch/index/mapper/internal/TTLFieldMapper.java b/core/src/main/java/org/elasticsearch/index/mapper/internal/TTLFieldMapper.java index dbf63a7f801..7c51b05cb4b 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/internal/TTLFieldMapper.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/internal/TTLFieldMapper.java @@ -212,7 +212,7 @@ public class TTLFieldMapper extends MetadataFieldMapper { @Override protected void parseCreateField(ParseContext context, List fields) throws IOException, AlreadyExpiredException { - if (enabledState.enabled && !context.sourceToParse().flyweight()) { + if (enabledState.enabled) { long ttl = context.sourceToParse().ttl(); if (ttl <= 0 && defaultTTL > 0) { // no ttl provided so we use the default value ttl = defaultTTL; diff --git a/core/src/main/java/org/elasticsearch/index/mapper/internal/UidFieldMapper.java b/core/src/main/java/org/elasticsearch/index/mapper/internal/UidFieldMapper.java index 828651409b1..f8fea4071e5 100644 --- a/core/src/main/java/org/elasticsearch/index/mapper/internal/UidFieldMapper.java +++ b/core/src/main/java/org/elasticsearch/index/mapper/internal/UidFieldMapper.java @@ -149,7 +149,7 @@ public class UidFieldMapper extends MetadataFieldMapper { @Override public void postParse(ParseContext context) throws IOException { - if (context.id() == null && !context.sourceToParse().flyweight()) { + if (context.id() == null) { throw new MapperParsingException("No id found while parsing the content source"); } // if we did not have the id as part of the sourceToParse, then we need to parse it here diff --git a/core/src/main/java/org/elasticsearch/index/percolator/PercolatorFieldMapper.java b/core/src/main/java/org/elasticsearch/index/percolator/PercolatorFieldMapper.java index c4b2b06e0e0..9a103195746 100644 --- a/core/src/main/java/org/elasticsearch/index/percolator/PercolatorFieldMapper.java +++ b/core/src/main/java/org/elasticsearch/index/percolator/PercolatorFieldMapper.java @@ -126,9 +126,7 @@ public class PercolatorFieldMapper extends FieldMapper { public Mapper parse(ParseContext context) throws IOException { QueryShardContext queryShardContext = new QueryShardContext(this.queryShardContext); Query query = PercolatorQueriesRegistry.parseQuery(queryShardContext, mapUnmappedFieldAsString, context.parser()); - if (context.flyweight() == false) { - ExtractQueryTermsService.extractQueryTerms(query, context.doc(), queryTermsField.name(), unknownQueryField.name(), queryTermsField.fieldType()); - } + ExtractQueryTermsService.extractQueryTerms(query, context.doc(), queryTermsField.name(), unknownQueryField.name(), queryTermsField.fieldType()); return null; } diff --git a/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java b/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java index c78d125ab46..97416e17211 100644 --- a/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java +++ b/core/src/main/java/org/elasticsearch/index/termvectors/TermVectorsService.java @@ -292,7 +292,7 @@ public class TermVectorsService { private ParsedDocument parseDocument(IndexShard indexShard, String index, String type, BytesReference doc) throws Throwable { MapperService mapperService = indexShard.mapperService(); DocumentMapperForType docMapper = mapperService.documentMapperWithAutoCreate(type); - ParsedDocument parsedDocument = docMapper.getDocumentMapper().parse(source(doc).index(index).type(type).flyweight(true)); + ParsedDocument parsedDocument = docMapper.getDocumentMapper().parse(source(doc).index(index).type(type).id("_id_for_tv_api")); if (docMapper.getMapping() != null) { parsedDocument.addDynamicMappingsUpdate(docMapper.getMapping()); } diff --git a/core/src/main/java/org/elasticsearch/percolator/PercolateDocumentParser.java b/core/src/main/java/org/elasticsearch/percolator/PercolateDocumentParser.java index 973aa18b8fc..946d30edcc4 100644 --- a/core/src/main/java/org/elasticsearch/percolator/PercolateDocumentParser.java +++ b/core/src/main/java/org/elasticsearch/percolator/PercolateDocumentParser.java @@ -24,6 +24,7 @@ import org.apache.lucene.search.Query; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.percolate.PercolateShardRequest; import org.elasticsearch.cluster.action.index.MappingUpdatedAction; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.BytesStreamOutput; @@ -34,6 +35,7 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.mapper.DocumentMapperForType; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.SearchParseElement; import org.elasticsearch.search.aggregations.AggregationPhase; @@ -93,7 +95,7 @@ public class PercolateDocumentParser { DocumentMapperForType docMapper = mapperService.documentMapperWithAutoCreate(request.documentType()); String index = context.shardTarget().index(); - doc = docMapper.getDocumentMapper().parse(source(parser).index(index).type(request.documentType()).flyweight(true)); + doc = docMapper.getDocumentMapper().parse(source(parser).index(index).type(request.documentType()).id("_id_for_percolate_api")); if (docMapper.getMapping() != null) { doc.addDynamicMappingsUpdate(docMapper.getMapping()); } @@ -202,19 +204,15 @@ public class PercolateDocumentParser { } private ParsedDocument parseFetchedDoc(PercolateContext context, BytesReference fetchedDoc, MapperService mapperService, String index, String type) { - try (XContentParser parser = XContentFactory.xContent(fetchedDoc).createParser(fetchedDoc)) { - DocumentMapperForType docMapper = mapperService.documentMapperWithAutoCreate(type); - ParsedDocument doc = docMapper.getDocumentMapper().parse(source(parser).index(index).type(type).flyweight(true)); - if (doc == null) { - throw new ElasticsearchParseException("No doc to percolate in the request"); - } - if (context.highlight() != null) { - doc.setSource(fetchedDoc); - } - return doc; - } catch (Throwable e) { - throw new ElasticsearchParseException("failed to parse request", e); + DocumentMapperForType docMapper = mapperService.documentMapperWithAutoCreate(type); + ParsedDocument doc = docMapper.getDocumentMapper().parse(source(fetchedDoc).index(index).type(type).id("_id_for_percolate_api")); + if (doc == null) { + throw new ElasticsearchParseException("No doc to percolate in the request"); } + if (context.highlight() != null) { + doc.setSource(fetchedDoc); + } + return doc; } } diff --git a/plugins/mapper-size/src/main/java/org/elasticsearch/index/mapper/size/SizeFieldMapper.java b/plugins/mapper-size/src/main/java/org/elasticsearch/index/mapper/size/SizeFieldMapper.java index 6cd54eeaac0..984e83a438e 100644 --- a/plugins/mapper-size/src/main/java/org/elasticsearch/index/mapper/size/SizeFieldMapper.java +++ b/plugins/mapper-size/src/main/java/org/elasticsearch/index/mapper/size/SizeFieldMapper.java @@ -150,7 +150,7 @@ public class SizeFieldMapper extends MetadataFieldMapper { if (!enabledState.enabled) { return; } - if (context.flyweight()) { + if (context.source() == null) { return; } fields.add(new IntegerFieldMapper.CustomIntegerNumericField(context.source().length(), fieldType())); From aefdee17fd0d7caacc769f2df54beede12adb3f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Wed, 27 Jan 2016 18:56:19 +0100 Subject: [PATCH 11/11] Adding builder method to SmoothingModel implementations Adds a method that emits a WordScorerFactory to all of the three SmoothingModel implementatins that will be needed when we switch to parsing the PhraseSuggestion on the coordinating node and need to delay creating the WordScorer on the shards. --- .../search/suggest/phrase/LaplaceScorer.java | 10 +- .../phrase/LinearInterpoatingScorer.java | 14 ++- .../phrase/PhraseSuggestionBuilder.java | 92 +++++++++++++------ .../suggest/phrase/StupidBackoffScorer.java | 4 + .../suggest/phrase/LaplaceModelTests.java | 17 +++- .../phrase/LinearInterpolationModelTests.java | 22 ++++- ...lTest.java => SmoothingModelTestCase.java} | 71 ++++++++++---- .../phrase/StupidBackoffModelTests.java | 17 +++- 8 files changed, 185 insertions(+), 62 deletions(-) rename core/src/test/java/org/elasticsearch/search/suggest/phrase/{SmoothingModelTest.java => SmoothingModelTestCase.java} (68%) diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java index 04d98c3827d..678f3082bac 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java @@ -27,14 +27,14 @@ import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidat import java.io.IOException; //TODO public for tests public final class LaplaceScorer extends WordScorer { - + public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() { @Override public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException { return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5); } }; - + private double alpha; public LaplaceScorer(IndexReader reader, Terms terms, String field, @@ -42,7 +42,11 @@ public final class LaplaceScorer extends WordScorer { super(reader, terms, field, realWordLikelyhood, separator); this.alpha = alpha; } - + + double alpha() { + return this.alpha; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java index d2b1ba48b13..368d461fc53 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java @@ -41,7 +41,19 @@ public final class LinearInterpoatingScorer extends WordScorer { this.bigramLambda = bigramLambda / sum; this.trigramLambda = trigramLambda / sum; } - + + double trigramLambda() { + return this.trigramLambda; + } + + double bigramLambda() { + return this.bigramLambda; + } + + double unigramLambda() { + return this.unigramLambda; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java index 97ca09d25a1..0e1fec6c7b2 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java @@ -18,8 +18,12 @@ */ package org.elasticsearch.search.suggest.phrase; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Terms; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -30,6 +34,7 @@ import org.elasticsearch.common.xcontent.XContentParser.Token; import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder; +import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory; import java.io.IOException; import java.util.ArrayList; @@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder> generators = new HashMap<>(); private Integer gramSize; - private SmoothingModel model; + private SmoothingModel model; private Boolean forceUnigrams; private Integer tokenLimit; private String preTag; @@ -159,7 +164,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder model) { + public PhraseSuggestionBuilder smoothingModel(SmoothingModel model) { this.model = model; return this; } @@ -292,7 +297,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class StupidBackoff extends SmoothingModel { + public static final class StupidBackoff extends SmoothingModel { /** * Default discount parameter for {@link StupidBackoff} smoothing */ @@ -341,8 +346,9 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount); + } } /** @@ -377,7 +389,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class Laplace extends SmoothingModel { + public static final class Laplace extends SmoothingModel { private double alpha = DEFAULT_LAPLACE_ALPHA; private static final String NAME = "laplace"; private static final ParseField ALPHA_FIELD = new ParseField("alpha"); @@ -419,13 +431,14 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha); + } } - public static abstract class SmoothingModel> implements NamedWriteable, ToXContent { + public static abstract class SmoothingModel implements NamedWriteable, ToXContent { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { @@ -471,16 +490,18 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder for details. *

*/ - public static final class LinearInterpolation extends SmoothingModel { + public static final class LinearInterpolation extends SmoothingModel { private static final String NAME = "linear"; static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1); private final double trigramLambda; @@ -563,10 +584,11 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder + new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda, + unigramLambda); } } diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java index fcf6064d228..5bd3d942b1a 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java @@ -42,6 +42,10 @@ public class StupidBackoffScorer extends WordScorer { this.discount = discount; } + double discount() { + return this.discount; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java index e2256e98f6a..87ad654e0cd 100644 --- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java @@ -20,11 +20,14 @@ package org.elasticsearch.search.suggest.phrase; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; -public class LaplaceModelTests extends SmoothingModelTest { +import static org.hamcrest.Matchers.instanceOf; + +public class LaplaceModelTests extends SmoothingModelTestCase { @Override - protected Laplace createTestModel() { + protected SmoothingModel createTestModel() { return new Laplace(randomDoubleBetween(0.0, 10.0, false)); } @@ -32,7 +35,15 @@ public class LaplaceModelTests extends SmoothingModelTest { * mutate the given model so the returned smoothing model is different */ @Override - protected Laplace createMutation(Laplace original) { + protected Laplace createMutation(SmoothingModel input) { + Laplace original = (Laplace) input; return new Laplace(original.getAlpha() + 0.1); } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel input) { + Laplace model = (Laplace) input; + assertThat(wordScorer, instanceOf(LaplaceScorer.class)); + assertEquals(model.getAlpha(), ((LaplaceScorer) wordScorer).alpha(), Double.MIN_VALUE); + } } diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java index 467bca7f0ab..1112b7a5ed7 100644 --- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java @@ -20,15 +20,18 @@ package org.elasticsearch.search.suggest.phrase; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; -public class LinearInterpolationModelTests extends SmoothingModelTest { +import static org.hamcrest.Matchers.instanceOf; + +public class LinearInterpolationModelTests extends SmoothingModelTestCase { @Override - protected LinearInterpolation createTestModel() { + protected SmoothingModel createTestModel() { double trigramLambda = randomDoubleBetween(0.0, 10.0, false); double bigramLambda = randomDoubleBetween(0.0, 10.0, false); double unigramLambda = randomDoubleBetween(0.0, 10.0, false); - // normalize + // normalize so parameters sum to 1 double sum = trigramLambda + bigramLambda + unigramLambda; return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum); } @@ -37,7 +40,8 @@ public class LinearInterpolationModelTests extends SmoothingModelTest> extends ESTestCase { +public abstract class SmoothingModelTestCase extends ESTestCase { private static NamedWriteableRegistry namedWriteableRegistry; - private static IndicesQueriesRegistry indicesQueriesRegistry; /** * setup for the whole base test class @@ -63,33 +76,31 @@ public abstract class SmoothingModelTest> extends E namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE); namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE); } - indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptySet(), namedWriteableRegistry); } @AfterClass public static void afterClass() throws Exception { namedWriteableRegistry = null; - indicesQueriesRegistry = null; } /** * create random model that is put under test */ - protected abstract SM createTestModel(); + protected abstract SmoothingModel createTestModel(); /** * mutate the given model so the returned smoothing model is different */ - protected abstract SM createMutation(SM original) throws IOException; + protected abstract SmoothingModel createMutation(SmoothingModel original) throws IOException; /** * Test that creates new smoothing model from a random test smoothing model and checks both for equality */ public void testFromXContent() throws IOException { - QueryParseContext context = new QueryParseContext(indicesQueriesRegistry); + QueryParseContext context = new QueryParseContext(new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptyMap())); context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY)); - SM testModel = createTestModel(); + SmoothingModel testModel = createTestModel(); XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); if (randomBoolean()) { contentBuilder.prettyPrint(); @@ -99,21 +110,45 @@ public abstract class SmoothingModelTest> extends E contentBuilder.endObject(); XContentParser parser = XContentHelper.createParser(contentBuilder.bytes()); context.reset(parser); - SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, + parser.nextToken(); // go to start token, real parsing would do that in the outer element parser + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, testModel.getWriteableName()); - SmoothingModel parsedModel = prototype.fromXContent(context); + SmoothingModel parsedModel = prototype.fromXContent(context); assertNotSame(testModel, parsedModel); assertEquals(testModel, parsedModel); assertEquals(testModel.hashCode(), parsedModel.hashCode()); } + /** + * Test the WordScorer emitted by the smoothing model + */ + public void testBuildWordScorer() throws IOException { + SmoothingModel testModel = createTestModel(); + + Map mapping = new HashMap<>(); + mapping.put("field", new WhitespaceAnalyzer()); + PerFieldAnalyzerWrapper wrapper = new PerFieldAnalyzerWrapper(new WhitespaceAnalyzer(), mapping); + IndexWriter writer = new IndexWriter(new RAMDirectory(), new IndexWriterConfig(wrapper)); + Document doc = new Document(); + doc.add(new Field("field", "someText", TextField.TYPE_NOT_STORED)); + writer.addDocument(doc); + DirectoryReader ir = DirectoryReader.open(writer, false); + + WordScorer wordScorer = testModel.buildWordScorerFactory().newScorer(ir, MultiFields.getTerms(ir , "field"), "field", 0.9d, BytesRefs.toBytesRef(" ")); + assertWordScorer(wordScorer, testModel); + } + + /** + * implementation dependant assertions on the wordScorer produced by the smoothing model under test + */ + abstract void assertWordScorer(WordScorer wordScorer, SmoothingModel testModel); + /** * Test serialization and deserialization of the tested model. */ - @SuppressWarnings("unchecked") public void testSerialization() throws IOException { - SM testModel = createTestModel(); - SM deserializedModel = (SM) copyModel(testModel); + SmoothingModel testModel = createTestModel(); + SmoothingModel deserializedModel = copyModel(testModel); assertEquals(testModel, deserializedModel); assertEquals(testModel.hashCode(), deserializedModel.hashCode()); assertNotSame(testModel, deserializedModel); @@ -124,7 +159,7 @@ public abstract class SmoothingModelTest> extends E */ @SuppressWarnings("unchecked") public void testEqualsAndHashcode() throws IOException { - SM firstModel = createTestModel(); + SmoothingModel firstModel = createTestModel(); assertFalse("smoothing model is equal to null", firstModel.equals(null)); assertFalse("smoothing model is equal to incompatible type", firstModel.equals("")); assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel)); @@ -132,13 +167,13 @@ public abstract class SmoothingModelTest> extends E equalTo(firstModel.hashCode())); assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel))); - SM secondModel = (SM) copyModel(firstModel); + SmoothingModel secondModel = copyModel(firstModel); assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel)); assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel)); assertTrue("equals is not symmetric", secondModel.equals(firstModel)); assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode())); - SM thirdModel = (SM) copyModel(secondModel); + SmoothingModel thirdModel = copyModel(secondModel); assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel)); assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel)); assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode())); @@ -148,11 +183,11 @@ public abstract class SmoothingModelTest> extends E assertTrue("equals is not symmetric", thirdModel.equals(firstModel)); } - static SmoothingModel copyModel(SmoothingModel original) throws IOException { + static SmoothingModel copyModel(SmoothingModel original) throws IOException { try (BytesStreamOutput output = new BytesStreamOutput()) { original.writeTo(output); try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { - SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); return prototype.readFrom(in); } } diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java index 5d774066e07..c3bd66d2a81 100644 --- a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java @@ -19,12 +19,15 @@ package org.elasticsearch.search.suggest.phrase; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; -public class StupidBackoffModelTests extends SmoothingModelTest { +import static org.hamcrest.Matchers.instanceOf; + +public class StupidBackoffModelTests extends SmoothingModelTestCase { @Override - protected StupidBackoff createTestModel() { + protected SmoothingModel createTestModel() { return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false)); } @@ -32,7 +35,15 @@ public class StupidBackoffModelTests extends SmoothingModelTest { * mutate the given model so the returned smoothing model is different */ @Override - protected StupidBackoff createMutation(StupidBackoff original) { + protected StupidBackoff createMutation(SmoothingModel input) { + StupidBackoff original = (StupidBackoff) input; return new StupidBackoff(original.getDiscount() + 0.1); } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel input) { + assertThat(wordScorer, instanceOf(StupidBackoffScorer.class)); + StupidBackoff testModel = (StupidBackoff) input; + assertEquals(testModel.getDiscount(), ((StupidBackoffScorer) wordScorer).discount(), Double.MIN_VALUE); + } }