Fix (de)serialization of async search failures (#55688)

The (de)serialization code of the async search response
cannot handle exceptions that extend ElasticsearchException (e.g. ScriptException).
This commit fixes this bug by serializing the error with the more generic
StreamInput#writeException.
This commit is contained in:
Jim Ferenczi 2020-04-23 23:00:34 +02:00 committed by jimczi
parent 8c7ef2417f
commit 31d1727698
8 changed files with 93 additions and 42 deletions

View File

@ -194,7 +194,7 @@ final class FetchSearchPhase extends SearchPhase {
try {
logger.debug(
() -> new ParameterizedMessage("[{}] Failed to execute fetch phase", fetchSearchRequest.contextId()), e);
progressListener.notifyFetchFailure(shardIndex, e);
progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
counter.onFailure(shardIndex, shardTarget, e);
} finally {
// the search context might not be cleared on the node where the fetch was executed for example

View File

@ -104,9 +104,10 @@ abstract class SearchProgressListener {
* Executed when a shard reports a fetch failure.
*
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}.
* @param shardTarget The last shard target that thrown an exception.
* @param exc The cause of the failure.
*/
protected void onFetchFailure(int shardIndex, Exception exc) {}
protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
final void notifyListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {
this.shards = shards;
@ -160,9 +161,9 @@ abstract class SearchProgressListener {
}
}
final void notifyFetchFailure(int shardIndex, Exception exc) {
final void notifyFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
try {
onFetchFailure(shardIndex, exc);
onFetchFailure(shardIndex, shardTarget, exc);
} catch (Exception e) {
logger.warn(() -> new ParameterizedMessage("[{}] Failed to execute progress listener on fetch failure",
shards.get(shardIndex)), e);

View File

@ -167,7 +167,7 @@ public class SearchProgressActionListenerIT extends ESSingleNodeTestCase {
}
@Override
public void onFetchFailure(int shardIndex, Exception exc) {
public void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
assertThat(shardIndex, lessThan(shardsListener.get().size()));
numFetchFailures.incrementAndGet();
}

View File

@ -345,12 +345,16 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
// best effort to cancel expired tasks
checkCancellation();
searchResponse.get().addShardFailure(shardIndex,
// the nodeId is null if all replicas of this shard failed
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
}
@Override
protected void onFetchFailure(int shardIndex, Exception exc) {
protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
checkCancellation();
searchResponse.get().addShardFailure(shardIndex,
// the nodeId is null if all replicas of this shard failed
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
}
@Override

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.search;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.common.settings.Settings;
@ -250,7 +249,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
assertNull(response.getSearchResponse());
assertNotNull(response.getFailure());
assertFalse(response.isRunning());
ElasticsearchException exc = response.getFailure();
Exception exc = response.getFailure();
assertThat(exc.getMessage(), containsString("no such index"));
}

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.search;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
@ -19,6 +18,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentElasticsearchExtension;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.script.ScriptException;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.test.ESTestCase;
@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.transform.transforms.TimeSyncConfig;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
@ -107,7 +108,8 @@ public class AsyncSearchResponseTests extends ESTestCase {
case 2:
return new AsyncSearchResponse(searchId, searchResponse,
new ElasticsearchException(new IOException("boum")), randomBoolean(), randomBoolean(),
new ScriptException("messageData", new Exception("causeData"), Arrays.asList("stack1", "stack2"),
"sourceData", "langData"), randomBoolean(), randomBoolean(),
randomNonNegativeLong(), randomNonNegativeLong());
default:

View File

@ -7,15 +7,16 @@ package org.elasticsearch.xpack.search;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchShard;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
@ -26,35 +27,19 @@ import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
public class AsyncSearchTaskTests extends ESTestCase {
private ThreadPool threadPool;
private static class TestTask extends CancellableTask {
private TestTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
}
@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
}
}
private static TestTask createSubmitTask() {
return new TestTask(0L, "", "", "test", new TaskId("node1", 0), Collections.emptyMap());
}
@Before
public void beforeTest() {
threadPool = new TestThreadPool(getTestName());
@ -65,6 +50,12 @@ public class AsyncSearchTaskTests extends ESTestCase {
threadPool.shutdownNow();
}
private AsyncSearchTask createAsyncSearchTask() {
return new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null);
}
public void testWaitForInit() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
@ -106,9 +97,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
}
public void testWithFailure() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null);
AsyncSearchTask task = createAsyncSearchTask();
int numThreads = randomIntBetween(1, 10);
CountDownLatch latch = new CountDownLatch(numThreads);
for (int i = 0; i < numThreads; i++) {
@ -134,9 +123,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
}
public void testWaitForCompletion() throws InterruptedException {
AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), () -> false, TimeValue.timeValueHours(1),
Collections.emptyMap(), Collections.emptyMap(), new AsyncExecutionId("0", new TaskId("node1", 1)),
new NoOpClient(threadPool), threadPool, null);
AsyncSearchTask task = createAsyncSearchTask();
int numShards = randomIntBetween(0, 10);
List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) {
@ -165,6 +152,42 @@ public class AsyncSearchTaskTests extends ESTestCase {
threadPool.shutdownNow();
}
public void testWithFetchFailures() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask();
int numShards = randomIntBetween(0, 10);
List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) {
shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
List<SearchShard> skippedShards = new ArrayList<>();
int numSkippedShards = randomIntBetween(0, 10);
for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, 0, true);
}
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
int numFetchFailures = randomIntBetween(1, numShards);
for (int i = 0; i < numFetchFailures; i++) {
task.getSearchProgressActionListener().onFetchFailure(i,
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
new IOException("boum"));
}
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards));
assertCompletionListeners(task, numShards+numSkippedShards,
numSkippedShards, numFetchFailures, false);
threadPool.shutdownNow();
}
private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards) {
InternalSearchResponse response = new InternalSearchResponse(SearchHits.empty(),
InternalAggregations.EMPTY, null, null, false, null, 1);
@ -187,6 +210,13 @@ public class AsyncSearchTaskTests extends ESTestCase {
assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards));
assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures));
assertThat(resp.isPartial(), equalTo(isPartial));
if (expectedShardFailures > 0) {
assertThat(resp.getSearchResponse().getShardFailures().length, equalTo(expectedShardFailures));
for (ShardSearchFailure failure : resp.getSearchResponse().getShardFailures()) {
assertThat(failure.getCause(), instanceOf(IOException.class));
assertThat(failure.getCause().getMessage(), equalTo("boum"));
}
}
latch.countDown();
}

View File

@ -6,6 +6,8 @@
package org.elasticsearch.xpack.core.search.action;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
@ -29,7 +31,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
@Nullable
private final SearchResponse searchResponse;
@Nullable
private final ElasticsearchException error;
private final Exception error;
private final boolean isRunning;
private final boolean isPartial;
@ -60,7 +62,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
*/
public AsyncSearchResponse(String id,
SearchResponse searchResponse,
ElasticsearchException error,
Exception error,
boolean isPartial,
boolean isRunning,
long startTimeMillis,
@ -76,7 +78,11 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
public AsyncSearchResponse(StreamInput in) throws IOException {
this.id = in.readOptionalString();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.error = in.readBoolean() ? in.readException() : null;
} else {
this.error = in.readOptionalWriteable(ElasticsearchException::new);
}
this.searchResponse = in.readOptionalWriteable(SearchResponse::new);
this.isPartial = in.readBoolean();
this.isRunning = in.readBoolean();
@ -87,7 +93,16 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(id);
out.writeOptionalWriteable(error);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
if (error != null) {
out.writeBoolean(true);
out.writeException(error);
} else {
out.writeBoolean(false);
}
} else {
out.writeOptionalWriteable(ExceptionsHelper.convertToElastic(error));
}
out.writeOptionalWriteable(searchResponse);
out.writeBoolean(isPartial);
out.writeBoolean(isRunning);
@ -120,7 +135,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
/**
* Returns the failure reason or null if the query is running or has completed normally.
*/
public ElasticsearchException getFailure() {
public Exception getFailure() {
return error;
}
@ -170,7 +185,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
// shard failures are not considered fatal for partial results so
// we return OK until we get the final response even if we don't have
// a single successful shard.
return error != null ? error.status() : OK;
return error != null ? ExceptionsHelper.status(error) : OK;
} else {
return searchResponse.status();
}
@ -193,7 +208,7 @@ public class AsyncSearchResponse extends ActionResponse implements StatusToXCont
}
if (error != null) {
builder.startObject("error");
error.toXContent(builder, params);
ElasticsearchException.generateThrowableXContent(builder, params, error);
builder.endObject();
}
builder.endObject();