Fix (de)serialization of async search failures (#55688)

The (de)serialization code of the async search response
cannot handle exceptions that extend ElasticsearchException (e.g. ScriptException).
This commit fixes this bug by serializing the error with the more generic
StreamInput#writeException.
This commit is contained in:
Jim Ferenczi 2020-04-23 23:00:34 +02:00 committed by jimczi
parent 8c7ef2417f
commit 31d1727698
8 changed files with 93 additions and 42 deletions

View File

@ -194,7 +194,7 @@ final class FetchSearchPhase extends SearchPhase {
try { try {
logger.debug( logger.debug(
() -> new ParameterizedMessage("[{}] Failed to execute fetch phase", fetchSearchRequest.contextId()), e); () -> new ParameterizedMessage("[{}] Failed to execute fetch phase", fetchSearchRequest.contextId()), e);
progressListener.notifyFetchFailure(shardIndex, e); progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
counter.onFailure(shardIndex, shardTarget, e); counter.onFailure(shardIndex, shardTarget, e);
} finally { } finally {
// the search context might not be cleared on the node where the fetch was executed for example // the search context might not be cleared on the node where the fetch was executed for example

View File

@ -104,9 +104,10 @@ abstract class SearchProgressListener {
* Executed when a shard reports a fetch failure. * Executed when a shard reports a fetch failure.
* *
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}.
* @param shardTarget The last shard target that thrown an exception.
* @param exc The cause of the failure. * @param exc The cause of the failure.
*/ */
protected void onFetchFailure(int shardIndex, Exception exc) {} protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
final void notifyListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) { final void notifyListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {
this.shards = shards; this.shards = shards;
@ -160,9 +161,9 @@ abstract class SearchProgressListener {
} }
} }
final void notifyFetchFailure(int shardIndex, Exception exc) { final void notifyFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
try { try {
onFetchFailure(shardIndex, exc); onFetchFailure(shardIndex, shardTarget, exc);
} catch (Exception e) { } catch (Exception e) {
logger.warn(() -> new ParameterizedMessage("[{}] Failed to execute progress listener on fetch failure", logger.warn(() -> new ParameterizedMessage("[{}] Failed to execute progress listener on fetch failure",
shards.get(shardIndex)), e); shards.get(shardIndex)), e);

View File

@ -167,7 +167,7 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase {
} }
@Override @Override
public void onFetchFailure(int shardIndex, Exception exc) { public void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
assertThat(shardIndex, lessThan(shardsListener.get().size())); assertThat(shardIndex, lessThan(shardsListener.get().size()));
numFetchFailures.incrementAndGet(); numFetchFailures.incrementAndGet();
} }

View File

@ -345,12 +345,16 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
// best effort to cancel expired tasks // best effort to cancel expired tasks
checkCancellation(); checkCancellation();
searchResponse.get().addShardFailure(shardIndex, searchResponse.get().addShardFailure(shardIndex,
// the nodeId is null if all replicas of this shard failed
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null)); new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
} }
@Override @Override
protected void onFetchFailure(int shardIndex, Exception exc) { protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
checkCancellation(); checkCancellation();
searchResponse.get().addShardFailure(shardIndex,
// the nodeId is null if all replicas of this shard failed
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
} }
@Override @Override

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.search; package org.elasticsearch.xpack.search;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -250,7 +249,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
assertNull(response.getSearchResponse()); assertNull(response.getSearchResponse());
assertNotNull(response.getFailure()); assertNotNull(response.getFailure());
assertFalse(response.isRunning()); assertFalse(response.isRunning());
ElasticsearchException exc = response.getFailure(); Exception exc = response.getFailure();
assertThat(exc.getMessage(), containsString("no such index")); assertThat(exc.getMessage(), containsString("no such index"));
} }

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.search; package org.elasticsearch.xpack.search;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.ShardSearchFailure;
@ -19,6 +18,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentElasticsearchExtension; import org.elasticsearch.common.xcontent.XContentElasticsearchExtension;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.script.ScriptException;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.transform.transforms.TimeSyncConfig;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
@ -107,7 +108,8 @@ public class AsyncSearchResponseTests extends ESTestCase {
case 2: case 2:
return new AsyncSearchResponse(searchId, searchResponse, return new AsyncSearchResponse(searchId, searchResponse,
new ElasticsearchException(new IOException("boum")), randomBoolean(), randomBoolean(), new ScriptException("messageData", new Exception("causeData"), Arrays.asList("stack1", "stack2"),
"sourceData", "langData"), randomBoolean(), randomBoolean(),
randomNonNegativeLong(), randomNonNegativeLong()); randomNonNegativeLong(), randomNonNegativeLong());
default: default:

View File

@ -7,15 +7,16 @@ package org.elasticsearch.xpack.search;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchShard; import org.elasticsearch.action.search.SearchShard;
import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.test.client.NoOpClient;
@ -26,35 +27,19 @@ import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
public class AsyncSearchTaskTests extends ESTestCase { public class AsyncSearchTaskTests extends ESTestCase {
private ThreadPool threadPool; private ThreadPool threadPool;
private static class TestTask extends CancellableTask {
private TestTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
}
@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
}
}
private static TestTask createSubmitTask() {
return new TestTask(0L, "", "", "test", new TaskId("node1", 0), Collections.emptyMap());
}
@Before @Before
public void beforeTest() { public void beforeTest() {
threadPool = new TestThreadPool(getTestName()); threadPool = new TestThreadPool(getTestName());
@ -65,6 +50,12 @@ public class AsyncSearchTaskTests extends ESTestCase {
threadPool.shutdownNow(); threadPool.shutdownNow();
} }
private AsyncSearchTask createAsyncSearchTask() {
return new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null);
}
public void testWaitForInit() throws InterruptedException { public void testWaitForInit() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1), AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)), Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
@ -106,9 +97,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
} }
public void testWithFailure() throws InterruptedException { public void testWithFailure() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1), AsyncSearchTask task = createAsyncSearchTask();
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null);
int numThreads = randomIntBetween(1, 10); int numThreads = randomIntBetween(1, 10);
CountDownLatch latch = new CountDownLatch(numThreads); CountDownLatch latch = new CountDownLatch(numThreads);
for (int i = 0; i < numThreads; i++) { for (int i = 0; i < numThreads; i++) {
@ -134,9 +123,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
} }
public void testWaitForCompletion() throws InterruptedException { public void testWaitForCompletion() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1), AsyncSearchTask task = createAsyncSearchTask();
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null);
int numShards = randomIntBetween(0, 10); int numShards = randomIntBetween(0, 10);
List<SearchShard> shards = new ArrayList<>(); List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) { for (int i = 0; i < numShards; i++) {
@ -165,6 +152,42 @@ public class AsyncSearchTaskTests extends ESTestCase {
threadPool.shutdownNow(); threadPool.shutdownNow();
} }
public void testWithFetchFailures() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask();
int numShards = randomIntBetween(0, 10);
List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) {
shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
List<SearchShard> skippedShards = new ArrayList<>();
int numSkippedShards = randomIntBetween(0, 10);
for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, 0, true);
}
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
int numFetchFailures = randomIntBetween(1, numShards);
for (int i = 0; i < numFetchFailures; i++) {
task.getSearchProgressActionListener().onFetchFailure(i,
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
new IOException("boum"));
}
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards));
assertCompletionListeners(task, numShards+numSkippedShards,
numSkippedShards, numFetchFailures, false);
threadPool.shutdownNow();
}
private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) { private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) {
InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(), InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(),
InternalAggregations.EMPTY, null, null, false, null, 1); InternalAggregations.EMPTY, null, null, false, null, 1);
@ -187,6 +210,13 @@ public class AsyncSearchTaskTests extends ESTestCase {
assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards)); assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards));
assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures)); assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures));
assertThat(resp.isPartial(), equalTo(isPartial)); assertThat(resp.isPartial(), equalTo(isPartial));
if (expectedShardFailures > 0) {
assertThat(resp.getSearchResponse().getShardFailures().length, equalTo(expectedShardFailures));
for (ShardSearchFailure failure : resp.getSearchResponse().getShardFailures()) {
assertThat(failure.getCause(), instanceOf(IOException.class));
assertThat(failure.getCause().getMessage(), equalTo("boum"));
}
}
latch.countDown(); latch.countDown();
} }

View File

@ -6,6 +6,8 @@
package org.elasticsearch.xpack.core.search.action; package org.elasticsearch.xpack.core.search.action;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
@ -29,7 +31,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
@Nullable @Nullable
private final SearchResponse searchResponse; private final SearchResponse searchResponse;
@Nullable @Nullable
private final ElasticsearchException error; private final Exception error;
private final boolean isRunning; private final boolean isRunning;
private final boolean isPartial; private final boolean isPartial;
@ -60,7 +62,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
*/ */
public AsyncSearchResponse(String id, public AsyncSearchResponse(String id,
SearchResponse searchResponse, SearchResponse searchResponse,
ElasticsearchException error, Exception error,
boolean isPartial, boolean isPartial,
boolean isRunning, boolean isRunning,
long startTimeMillis, long startTimeMillis,
@ -76,7 +78,11 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
public AsyncSearchResponse(StreamInput in) throws IOException { public AsyncSearchResponse(StreamInput in) throws IOException {
this.id = in.readOptionalString(); this.id = in.readOptionalString();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.error = in.readBoolean() ? in.readException() : null;
} else {
this.error = in.readOptionalWriteable(ElasticsearchException::new); this.error = in.readOptionalWriteable(ElasticsearchException::new);
}
this.searchResponse = in.readOptionalWriteable(SearchResponse::new); this.searchResponse = in.readOptionalWriteable(SearchResponse::new);
this.isPartial = in.readBoolean(); this.isPartial = in.readBoolean();
this.isRunning = in.readBoolean(); this.isRunning = in.readBoolean();
@ -87,7 +93,16 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(id); out.writeOptionalString(id);
out.writeOptionalWriteable(error); if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
if (error != null) {
out.writeBoolean(true);
out.writeException(error);
} else {
out.writeBoolean(false);
}
} else {
out.writeOptionalWriteable(ExceptionsHelper.convertToElastic(error));
}
out.writeOptionalWriteable(searchResponse); out.writeOptionalWriteable(searchResponse);
out.writeBoolean(isPartial); out.writeBoolean(isPartial);
out.writeBoolean(isRunning); out.writeBoolean(isRunning);
@ -120,7 +135,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
/** /**
* Returns the failure reason or null if the query is running or has completed normally. * Returns the failure reason or null if the query is running or has completed normally.
*/ */
public ElasticsearchException getFailure() { public Exception getFailure() {
return error; return error;
} }
@ -170,7 +185,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
// shard failures are not considered fatal for partial results so // shard failures are not considered fatal for partial results so
// we return OK until we get the final response even if we don't have // we return OK until we get the final response even if we don't have
// a single successful shard. // a single successful shard.
return error != null ? error.status() : OK; return error != null ? ExceptionsHelper.status(error) : OK;
} else { } else {
return searchResponse.status(); return searchResponse.status();
} }
@ -193,7 +208,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
} }
if (error != null) { if (error != null) {
builder.startObject("error"); builder.startObject("error");
error.toXContent(builder, params); ElasticsearchException.generateThrowableXContent(builder, params, error);
builder.endObject(); builder.endObject();
} }
builder.endObject(); builder.endObject();