[ML] job results provider refactoring (#52012) (#52238)

During a bug hunt, I caught a handful of things (unrelated to the bug) that are potential issues:

1. Needless wrapping in exception handling (minor cleanup)
2. Potential for notifying listeners of a failure multiple times, and even notifying of a success after a failure has already been signalled
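
To illustrate point 2, here is a minimal, self-contained sketch of the bug and its fix. The Listener interface below is a hypothetical stand-in for Elasticsearch's ActionListener; the missing return statement is the entire bug.

import java.util.List;

interface Listener<T> {
    void onResponse(T result);
    void onFailure(Exception e);
}

class BatchHandler {

    // Anti-pattern: the listener is told about the failure, but the loop keeps
    // running, so it can be told about further failures and finally a success.
    static void handleBuggy(List<String> items, Listener<Integer> listener) {
        int processed = 0;
        for (String item : items) {
            if (item.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("empty item"));
                // BUG: no return -- control falls through to onResponse below
            }
            processed++;
        }
        listener.onResponse(processed); // may fire after onFailure
    }

    // Fix: bail out immediately after the first failure notification.
    static void handleFixed(List<String> items, Listener<Integer> listener) {
        int processed = 0;
        for (String item : items) {
            if (item.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("empty item"));
                return; // notify exactly once
            }
            processed++;
        }
        listener.onResponse(processed);
    }
}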
Benjamin Trent 2020-02-11 17:54:44 -05:00 committed by GitHub
parent 28c56da754
commit 2a968f4f2b
5 changed files with 151 additions and 130 deletions

TransportDeleteJobAction.java

@@ -355,7 +355,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction<DeleteJo
                 }
             },
             failure -> {
-                if (failure.getClass() == IndexNotFoundException.class) { // assume the index is already deleted
+                if (ExceptionsHelper.unwrapCause(failure) instanceof IndexNotFoundException) { // assume the index is already deleted
                     deleteByQueryExecutor.onResponse(false); // skip DBQ && Alias
                 } else {
                     failureHandler.accept(failure);
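
A note on the unwrapCause change above: an exact class comparison misses failures that arrive wrapped in a transport-layer exception, whereas unwrapping first and then using instanceof matches the underlying cause. A runnable analogue in plain Java, using ExecutionException as the stand-in wrapper (the real code uses Elasticsearch's ExceptionsHelper.unwrapCause):

import java.util.concurrent.ExecutionException;

class UnwrapDemo {

    // Stand-in for ExceptionsHelper.unwrapCause: peel wrapper exceptions
    // until the underlying cause is reached.
    static Throwable unwrapCause(Throwable t) {
        Throwable result = t;
        while (result instanceof ExecutionException && result.getCause() != null) {
            result = result.getCause();
        }
        return result;
    }

    public static void main(String[] args) {
        Exception wrapped = new ExecutionException(new IllegalStateException("index missing"));

        // Exact class comparison fails on the wrapped exception:
        System.out.println(wrapped.getClass() == IllegalStateException.class);     // false

        // Unwrapping first, then an instanceof check, matches as intended:
        System.out.println(unwrapCause(wrapped) instanceof IllegalStateException); // true
    }
}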
@@ -403,9 +403,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction<DeleteJo
         // Step 4. Get the job as the initial result index name is required
         ActionListener<Boolean> deleteCategorizerStateHandler = ActionListener.wrap(
-            response -> {
-                jobConfigProvider.getJob(jobId, getJobHandler);
-            },
+            response -> jobConfigProvider.getJob(jobId, getJobHandler),
             failureHandler
         );
@@ -509,9 +507,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction<DeleteJo
                     return;
                 }
                 executeAsyncWithOrigin(parentTaskClient.threadPool().getThreadContext(), ML_ORIGIN, removeRequest,
-                    ActionListener.<AcknowledgedResponse>wrap(
-                        finishedHandler::onResponse,
-                        finishedHandler::onFailure),
+                    finishedHandler,
                     parentTaskClient.admin().indices()::aliases);
             },
             finishedHandler::onFailure), parentTaskClient.admin().indices()::getAliases);
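
The simplification above relies on ActionListener.wrap(l::onResponse, l::onFailure) being an identity wrapper around l, so l can be passed directly. A minimal sketch with a hypothetical Listener interface and wrap helper mirroring that shape:

import java.util.function.Consumer;

class WrapDemo {

    interface Listener<T> {
        void onResponse(T result);
        void onFailure(Exception e);
    }

    // Minimal analogue of ActionListener.wrap(onResponse, onFailure).
    static <T> Listener<T> wrap(Consumer<T> onResponse, Consumer<Exception> onFailure) {
        return new Listener<T>() {
            @Override public void onResponse(T result) { onResponse.accept(result); }
            @Override public void onFailure(Exception e) { onFailure.accept(e); }
        };
    }

    public static void main(String[] args) {
        Listener<String> finishedHandler = wrap(
            r -> System.out.println("ok: " + r),
            e -> System.out.println("err: " + e));

        // wrap(l::onResponse, l::onFailure) delegates every callback straight
        // back to l, so the wrapper adds nothing and l can be passed directly.
        Listener<String> wrapped = wrap(finishedHandler::onResponse, finishedHandler::onFailure);
        wrapped.onResponse("acknowledged");          // prints: ok: acknowledged
        finishedHandler.onResponse("acknowledged");  // identical behaviour
    }
}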

TransportGetDatafeedsStatsAction.java

@@ -90,7 +90,7 @@ public class TransportGetDatafeedsStatsAction extends TransportMasterNodeReadAct
                     .collect(Collectors.toList());
                 jobResultsProvider.datafeedTimingStats(
                     jobIds,
-                    timingStatsByJobId -> {
+                    ActionListener.wrap(timingStatsByJobId -> {
                         List<GetDatafeedsStatsAction.Response.DatafeedStats> results = expandedIds.stream()
                             .map(datafeedId -> {
                                 DatafeedConfig config = existingConfigs.get(datafeedId);
@@ -109,7 +109,7 @@ public class TransportGetDatafeedsStatsAction extends TransportMasterNodeReadAct
                             new QueryPage<>(results, results.size(), DatafeedConfig.RESULTS_FIELD);
                         listener.onResponse(new GetDatafeedsStatsAction.Response(statsPage));
                     },
-                    listener::onFailure);
+                    listener::onFailure));
             },
             listener::onFailure)
         );
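
The call-site change above mirrors the new JobResultsProvider signature: the (Consumer<T>, Consumer<Exception>) pair collapses into one ActionListener. A sketch of the before/after shape with hypothetical names (not the actual Elasticsearch API):

import java.util.Map;
import java.util.function.Consumer;

class ListenerRefactorDemo {

    interface Listener<T> {
        void onResponse(T result);
        void onFailure(Exception e);
    }

    // Before: success and failure travel on two independent callbacks, which
    // makes it possible to invoke both, or one of them more than once.
    static void statsBefore(Consumer<Map<String, Long>> handler, Consumer<Exception> errorHandler) {
        handler.accept(Map.of("job-1", 42L));
    }

    // After: a single listener owns both outcomes, so every exit path has to
    // choose exactly one of onResponse/onFailure.
    static void statsAfter(Listener<Map<String, Long>> listener) {
        listener.onResponse(Map.of("job-1", 42L));
    }

    public static void main(String[] args) {
        statsBefore(stats -> System.out.println(stats), e -> System.err.println(e));
        statsAfter(new Listener<Map<String, Long>>() {
            @Override public void onResponse(Map<String, Long> stats) { System.out.println(stats); }
            @Override public void onFailure(Exception e) { System.err.println(e); }
        });
    }
}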

JobDataDeleter.java

@@ -83,11 +83,7 @@ public class JobDataDeleter {
         // _doc is the most efficient sort order and will also disable scoring
         deleteByQueryRequest.getSearchRequest().source().sort(ElasticsearchMappings.ES_DOC);
-        try {
-            executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, deleteByQueryRequest, listener);
-        } catch (Exception e) {
-            listener.onFailure(e);
-        }
+        executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, deleteByQueryRequest, listener);
     }

     /**
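
On point 1 of the commit message: the removed try/catch only ever guarded the synchronous dispatch step, and when the dispatcher already routes its own failures to the listener the wrapper is dead code. A sketch under that assumption, with a hypothetical executeAsync standing in for executeAsyncWithOrigin:

class AsyncDispatchDemo {

    interface Listener<T> {
        void onResponse(T result);
        void onFailure(Exception e);
    }

    // Hypothetical dispatcher: like the real helper is assumed to do, it
    // catches synchronous failures and reports them through the listener.
    static <T> void executeAsync(Runnable work, Listener<T> listener) {
        try {
            work.run();
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    static void deleteRedundant(Listener<Boolean> listener) {
        try {
            executeAsync(() -> { throw new IllegalStateException("boom"); }, listener);
        } catch (Exception e) {
            listener.onFailure(e); // never reached: executeAsync does not throw
        }
    }

    static void deleteClean(Listener<Boolean> listener) {
        executeAsync(() -> { throw new IllegalStateException("boom"); }, listener);
    }
}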

JobResultsProvider.java

@@ -446,10 +446,9 @@ public class JobResultsProvider {
             .addSort(SortBuilders.fieldSort(TimingStats.BUCKET_COUNT.getPreferredName()).order(SortOrder.DESC));
     }

-    public void datafeedTimingStats(List<String> jobIds, Consumer<Map<String, DatafeedTimingStats>> handler,
-                                    Consumer<Exception> errorHandler) {
+    public void datafeedTimingStats(List<String> jobIds, ActionListener<Map<String, DatafeedTimingStats>> listener) {
         if (jobIds.isEmpty()) {
-            handler.accept(Collections.emptyMap());
+            listener.onResponse(Collections.emptyMap());
             return;
         }
         MultiSearchRequestBuilder msearchRequestBuilder = client.prepareMultiSearch();
@@ -470,40 +469,45 @@
                     String jobId = jobIds.get(i);
                     MultiSearchResponse.Item itemResponse = msearchResponse.getResponses()[i];
                     if (itemResponse.isFailure()) {
-                        errorHandler.accept(itemResponse.getFailure());
-                    } else {
+                        listener.onFailure(itemResponse.getFailure());
+                        return;
+                    }
                     SearchResponse searchResponse = itemResponse.getResponse();
                     ShardSearchFailure[] shardFailures = searchResponse.getShardFailures();
                     int unavailableShards = searchResponse.getTotalShards() - searchResponse.getSuccessfulShards();
                     if (shardFailures != null && shardFailures.length > 0) {
                         LOGGER.error("[{}] Search request returned shard failures: {}", jobId, Arrays.toString(shardFailures));
-                        errorHandler.accept(
+                        listener.onFailure(
                             new ElasticsearchException(ExceptionsHelper.shardFailuresToErrorMsg(jobId, shardFailures)));
-                    } else if (unavailableShards > 0) {
-                        errorHandler.accept(
+                        return;
+                    }
+                    if (unavailableShards > 0) {
+                        listener.onFailure(
                             new ElasticsearchException(
                                 "[" + jobId + "] Search request encountered [" + unavailableShards + "] unavailable shards"));
-                    } else {
+                        return;
+                    }
                     SearchHits hits = searchResponse.getHits();
                     long hitsCount = hits.getHits().length;
-                    if (hitsCount == 0) {
+                    if (hitsCount == 0 || hitsCount > 1) {
                         SearchRequest searchRequest = msearchRequest.requests().get(i);
-                        LOGGER.debug("Found 0 hits for [{}]", new Object[]{searchRequest.indices()});
-                    } else if (hitsCount > 1) {
-                        SearchRequest searchRequest = msearchRequest.requests().get(i);
-                        LOGGER.debug("Found multiple hits for [{}]", new Object[]{searchRequest.indices()});
-                    } else {
-                        assert hitsCount == 1;
+                        LOGGER.debug("Found {} hits for [{}]",
+                            hitsCount == 0 ? "0" : "multiple",
+                            new Object[]{searchRequest.indices()});
+                        continue;
+                    }
                     SearchHit hit = hits.getHits()[0];
-                    DatafeedTimingStats timingStats = parseSearchHit(hit, DatafeedTimingStats.PARSER, errorHandler);
-                    timingStatsByJobId.put(jobId, timingStats);
-                    }
-                    }
-                    }
+                    try {
+                        DatafeedTimingStats timingStats = parseSearchHit(hit, DatafeedTimingStats.PARSER);
+                        timingStatsByJobId.put(jobId, timingStats);
+                    } catch (Exception e) {
+                        listener.onFailure(e);
+                        return;
+                    }
                 }
-                handler.accept(timingStatsByJobId);
+                listener.onResponse(timingStatsByJobId);
             },
-            errorHandler
+            listener::onFailure
             ),
             client::multiSearch);
     }
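
The restructured loop above swaps nested else branches for guard clauses that either return (terminal failure, notify once and stop) or continue (skippable item). A self-contained sketch of the same control flow over a batch of results, with a hypothetical Result type:

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class GuardClauseDemo {

    interface Listener<T> {
        void onResponse(T result);
        void onFailure(Exception e);
    }

    record Result(String id, boolean failed, Integer value) {}

    static void collect(List<Result> batch, Listener<Map<String, Integer>> listener) {
        Map<String, Integer> byId = new LinkedHashMap<>();
        for (Result r : batch) {
            if (r.failed()) {
                listener.onFailure(new IllegalStateException("item " + r.id() + " failed"));
                return;   // terminal: notify exactly once and stop
            }
            if (r.value() == null) {
                continue; // skippable: nothing to record for this item
            }
            byId.put(r.id(), r.value());
        }
        listener.onResponse(byId); // reached only if no failure was signalled
    }
}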
@@ -572,7 +576,8 @@ public class JobResultsProvider {
                 MultiSearchResponse.Item itemResponse = response.getResponses()[i];
                 if (itemResponse.isFailure()) {
                     errorHandler.accept(itemResponse.getFailure());
-                } else {
+                    return;
+                }
                 SearchResponse searchResponse = itemResponse.getResponse();
                 ShardSearchFailure[] shardFailures = searchResponse.getShardFailures();
                 int unavailableShards = searchResponse.getTotalShards() - searchResponse.getSuccessfulShards();
@@ -581,24 +586,28 @@
                         Arrays.toString(shardFailures));
                     errorHandler.accept(new ElasticsearchException(
                         ExceptionsHelper.shardFailuresToErrorMsg(jobId, shardFailures)));
-                } else if (unavailableShards > 0) {
+                    return;
+                }
+                if (unavailableShards > 0) {
                     errorHandler.accept(new ElasticsearchException("[" + jobId
                         + "] Search request encountered [" + unavailableShards + "] unavailable shards"));
-                } else {
+                    return;
+                }
                 SearchHits hits = searchResponse.getHits();
                 long hitsCount = hits.getHits().length;
                 if (hitsCount == 0) {
                     SearchRequest searchRequest = msearch.request().requests().get(i);
                     LOGGER.debug("Found 0 hits for [{}]", new Object[]{searchRequest.indices()});
-                } else {
+                }
                 for (SearchHit hit : hits) {
-                    parseAutodetectParamSearchHit(jobId, paramsBuilder, hit, errorHandler);
+                    try {
+                        parseAutodetectParamSearchHit(jobId, paramsBuilder, hit);
+                    } catch (Exception e) {
+                        errorHandler.accept(e);
+                        return;
+                    }
                 }
-                }
-                }
-                }
             }
             getScheduledEventsListener.onResponse(paramsBuilder);
         },
         errorHandler
@@ -612,38 +621,47 @@
             .setRouting(id);
     }

-    private static void parseAutodetectParamSearchHit(String jobId, AutodetectParams.Builder paramsBuilder, SearchHit hit,
-                                                      Consumer<Exception> errorHandler) {
+    /**
+     * @throws ElasticsearchException when search hit cannot be parsed
+     * @throws IllegalStateException when search hit has an unexpected ID
+     */
+    private static void parseAutodetectParamSearchHit(String jobId,
+                                                      AutodetectParams.Builder paramsBuilder,
+                                                      SearchHit hit) {
         String hitId = hit.getId();
         if (DataCounts.documentId(jobId).equals(hitId)) {
-            paramsBuilder.setDataCounts(parseSearchHit(hit, DataCounts.PARSER, errorHandler));
+            paramsBuilder.setDataCounts(parseSearchHit(hit, DataCounts.PARSER));
         } else if (TimingStats.documentId(jobId).equals(hitId)) {
-            paramsBuilder.setTimingStats(parseSearchHit(hit, TimingStats.PARSER, errorHandler));
+            paramsBuilder.setTimingStats(parseSearchHit(hit, TimingStats.PARSER));
         } else if (hitId.startsWith(ModelSizeStats.documentIdPrefix(jobId))) {
-            ModelSizeStats.Builder modelSizeStats = parseSearchHit(hit, ModelSizeStats.LENIENT_PARSER, errorHandler);
+            ModelSizeStats.Builder modelSizeStats = parseSearchHit(hit, ModelSizeStats.LENIENT_PARSER);
             paramsBuilder.setModelSizeStats(modelSizeStats == null ? null : modelSizeStats.build());
         } else if (hitId.startsWith(ModelSnapshot.documentIdPrefix(jobId))) {
-            ModelSnapshot.Builder modelSnapshot = parseSearchHit(hit, ModelSnapshot.LENIENT_PARSER, errorHandler);
+            ModelSnapshot.Builder modelSnapshot = parseSearchHit(hit, ModelSnapshot.LENIENT_PARSER);
             paramsBuilder.setModelSnapshot(modelSnapshot == null ? null : modelSnapshot.build());
         } else if (Quantiles.documentId(jobId).equals(hit.getId())) {
-            paramsBuilder.setQuantiles(parseSearchHit(hit, Quantiles.LENIENT_PARSER, errorHandler));
+            paramsBuilder.setQuantiles(parseSearchHit(hit, Quantiles.LENIENT_PARSER));
         } else if (hitId.startsWith(MlFilter.DOCUMENT_ID_PREFIX)) {
-            paramsBuilder.addFilter(parseSearchHit(hit, MlFilter.LENIENT_PARSER, errorHandler).build());
+            paramsBuilder.addFilter(parseSearchHit(hit, MlFilter.LENIENT_PARSER).build());
         } else {
-            errorHandler.accept(new IllegalStateException("Unexpected Id [" + hitId + "]"));
+            throw new IllegalStateException("Unexpected Id [" + hitId + "]");
         }
     }

-    private static <T, U> T parseSearchHit(SearchHit hit, BiFunction<XContentParser, U, T> objectParser,
-                                           Consumer<Exception> errorHandler) {
+    /**
+     * @param hit The search hit to parse
+     * @param objectParser Parser for the object of type T
+     * @return The parsed value of T from the search hit
+     * @throws ElasticsearchException on failure
+     */
+    private static <T, U> T parseSearchHit(SearchHit hit, BiFunction<XContentParser, U, T> objectParser) {
         BytesReference source = hit.getSourceRef();
         try (InputStream stream = source.streamInput();
              XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                  .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) {
             return objectParser.apply(parser, null);
         } catch (IOException e) {
-            errorHandler.accept(new ElasticsearchParseException("failed to parse " + hit.getId(), e));
-            return null;
+            throw new ElasticsearchParseException("failed to parse " + hit.getId(), e);
         }
     }
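
The reworked helper above throws instead of feeding a side-channel error handler and returning null, which removes the "error reported, null still propagated" failure mode. A self-contained sketch of the two styles (the parse logic is hypothetical):

import java.util.function.Consumer;

class ParseHelperDemo {

    // Before: failure goes out through a callback, but the method still
    // returns null, which the caller can accidentally keep using.
    static Integer parseOrNull(String source, Consumer<Exception> errorHandler) {
        try {
            return Integer.parseInt(source);
        } catch (NumberFormatException e) {
            errorHandler.accept(e);
            return null; // second, silent channel: caller may NPE later
        }
    }

    // After: failure travels on exactly one channel, and the single catch
    // site decides where it goes (e.g. listener.onFailure).
    static int parseOrThrow(String source) {
        return Integer.parseInt(source); // NumberFormatException propagates
    }

    public static void main(String[] args) {
        try {
            System.out.println(parseOrThrow("42"));
            System.out.println(parseOrThrow("not-a-number"));
        } catch (NumberFormatException e) {
            System.err.println("failed to parse: " + e.getMessage());
        }
    }
}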
@@ -1091,7 +1109,12 @@
                 LOGGER.trace("No {} for job with id {}", resultDescription, jobId);
                 handler.accept(new Result<>(null, notFoundSupplier.get()));
             } else if (hits.length == 1) {
-                handler.accept(new Result<>(hits[0].getIndex(), parseSearchHit(hits[0], objectParser, errorHandler)));
+                try {
+                    T result = parseSearchHit(hits[0], objectParser);
+                    handler.accept(new Result<>(hits[0].getIndex(), result));
+                } catch (Exception e) {
+                    errorHandler.accept(e);
+                }
             } else {
                 errorHandler.accept(new IllegalStateException("Search for unique [" + resultDescription + "] returned ["
                     + hits.length + "] hits even though size was 1"));
@@ -1229,14 +1252,18 @@
             response -> {
                 List<ScheduledEvent> events = new ArrayList<>();
                 SearchHit[] hits = response.getHits().getHits();
+                try {
                     for (SearchHit hit : hits) {
-                        ScheduledEvent.Builder event = parseSearchHit(hit, ScheduledEvent.LENIENT_PARSER, handler::onFailure);
+                        ScheduledEvent.Builder event = parseSearchHit(hit, ScheduledEvent.LENIENT_PARSER);
                         event.eventId(hit.getId());
                         events.add(event.build());
                     }
                     handler.onResponse(new QueryPage<>(events, response.getHits().getTotalHits().value,
                         ScheduledEvent.RESULTS_FIELD));
+                } catch (Exception e) {
+                    handler.onFailure(e);
+                }
             },
             handler::onFailure),
             client::search);
@@ -1357,12 +1384,15 @@
             response -> {
                 List<Calendar> calendars = new ArrayList<>();
                 SearchHit[] hits = response.getHits().getHits();
+                try {
                     for (SearchHit hit : hits) {
-                        calendars.add(parseSearchHit(hit, Calendar.LENIENT_PARSER, listener::onFailure).build());
+                        calendars.add(parseSearchHit(hit, Calendar.LENIENT_PARSER).build());
                     }
                     listener.onResponse(new QueryPage<>(calendars, response.getHits().getTotalHits().value,
                         Calendar.RESULTS_FIELD));
+                } catch (Exception e) {
+                    listener.onFailure(e);
+                }
             },
             listener::onFailure)
             , client::search);
@@ -1370,13 +1400,8 @@
     public void removeJobFromCalendars(String jobId, ActionListener<Boolean> listener) {
-        ActionListener<BulkResponse> updateCalandarsListener = ActionListener.wrap(
-            r -> {
-                if (r.hasFailures()) {
-                    listener.onResponse(false);
-                }
-                listener.onResponse(true);
-            },
+        ActionListener<BulkResponse> updateCalendarsListener = ActionListener.wrap(
+            r -> listener.onResponse(r.hasFailures() == false),
             listener::onFailure
         );
@@ -1384,23 +1409,24 @@
             r -> {
                 BulkRequestBuilder bulkUpdate = client.prepareBulk();
                 bulkUpdate.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-                r.results().stream()
-                    .map(c -> {
-                        Set<String> ids = new HashSet<>(c.getJobIds());
-                        ids.remove(jobId);
-                        return new Calendar(c.getId(), new ArrayList<>(ids), c.getDescription());
-                    }).forEach(c -> {
-                        UpdateRequest updateRequest = new UpdateRequest(MlMetaIndex.INDEX_NAME, c.documentId());
+                for (Calendar calendar : r.results()) {
+                    List<String> ids = calendar.getJobIds()
+                        .stream()
+                        .filter(jId -> jobId.equals(jId) == false)
+                        .collect(Collectors.toList());
+                    Calendar newCalendar = new Calendar(calendar.getId(), ids, calendar.getDescription());
+                    UpdateRequest updateRequest = new UpdateRequest(MlMetaIndex.INDEX_NAME, newCalendar.documentId());
                     try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
-                        updateRequest.doc(c.toXContent(builder, ToXContent.EMPTY_PARAMS));
+                        updateRequest.doc(newCalendar.toXContent(builder, ToXContent.EMPTY_PARAMS));
                     } catch (IOException e) {
-                        throw new IllegalStateException("Failed to serialise calendar with id [" + c.getId() + "]", e);
+                        listener.onFailure(
+                            new IllegalStateException("Failed to serialise calendar with id [" + newCalendar.getId() + "]", e));
+                        return;
                     }
                     bulkUpdate.add(updateRequest);
-                    });
+                }
                 if (bulkUpdate.numberOfActions() > 0) {
-                    executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkUpdate.request(), updateCalandarsListener);
+                    executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkUpdate.request(), updateCalendarsListener);
                 } else {
                     listener.onResponse(true);
                 }
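
The switch from stream().forEach to a plain for loop above is what makes the new early return possible: inside a forEach lambda, return only exits the lambda for the current element, so the old code had no way to abort the traversal except by throwing. A self-contained sketch:

import java.util.List;

class LoopVsForEachDemo {

    interface Listener {
        void onFailure(Exception e);
    }

    // Inside forEach, the whole traversal can only be aborted by throwing;
    // the listener cannot be notified and the method exited in one step.
    static void processWithForEach(List<String> items, Listener listener) {
        items.forEach(item -> {
            if (item.isEmpty()) {
                throw new IllegalStateException("bad item"); // escapes via exception
            }
            System.out.println("processed " + item);
        });
    }

    // A plain loop can notify the listener and return from the method directly.
    static void processWithLoop(List<String> items, Listener listener) {
        for (String item : items) {
            if (item.isEmpty()) {
                listener.onFailure(new IllegalStateException("bad item"));
                return; // abort cleanly, exactly one notification
            }
            System.out.println("processed " + item);
        }
    }
}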

JobResultsProviderTests.java

@@ -867,7 +867,7 @@ public class JobResultsProviderTests extends ESTestCase {
         provider.timingStats(
             "foo",
             stats -> assertThat(stats, equalTo(new TimingStats("foo", 7, 1.0, 1000.0, 666.0, 777.0, context))),
-            e -> { throw new AssertionError(); });
+            e -> { throw new AssertionError("Failure getting timing stats", e); });

         verify(client).prepareSearch(indexName);
         verify(client).threadPool();
@@ -888,7 +888,7 @@
         provider.timingStats(
             "foo",
             stats -> assertThat(stats, equalTo(new TimingStats("foo"))),
-            e -> { throw new AssertionError(); });
+            e -> { throw new AssertionError("Failure getting timing stats", e); });

         verify(client).prepareSearch(indexName);
         verify(client).threadPool();
@@ -902,9 +902,9 @@
         JobResultsProvider provider = createProvider(client);

         provider.datafeedTimingStats(
             Arrays.asList(),
-            statsByJobId -> assertThat(statsByJobId, anEmptyMap()),
-            e -> { throw new AssertionError(); });
+            ActionListener.wrap(
+                statsByJobId -> assertThat(statsByJobId, anEmptyMap()),
+                e -> { throw new AssertionError("Failure getting datafeed timing stats", e); }));

         verifyZeroInteractions(client);
     }
@@ -971,8 +971,11 @@
         expectedStatsByJobId.put("bar", new DatafeedTimingStats("bar", 7, 77, 777.0, contextBar));

         provider.datafeedTimingStats(
             Arrays.asList("foo", "bar"),
-            statsByJobId -> assertThat(statsByJobId, equalTo(expectedStatsByJobId)),
-            e -> { throw new AssertionError(); });
+            ActionListener.wrap(
+                statsByJobId ->
+                    assertThat(statsByJobId, equalTo(expectedStatsByJobId)),
+                e -> fail(e.getMessage())
+            ));

         verify(client).threadPool();
         verify(client).prepareMultiSearch();
@@ -1009,7 +1012,7 @@
         provider.datafeedTimingStats(
             "foo",
             stats -> assertThat(stats, equalTo(new DatafeedTimingStats("foo", 6, 66, 666.0, contextFoo))),
-            e -> { throw new AssertionError(); });
+            e -> { throw new AssertionError("Failure getting datafeed timing stats", e); });

         verify(client).prepareSearch(indexName);
         verify(client).threadPool();
@@ -1030,7 +1033,7 @@
         provider.datafeedTimingStats(
             "foo",
             stats -> assertThat(stats, equalTo(new DatafeedTimingStats("foo"))),
-            e -> { throw new AssertionError(); });
+            e -> { throw new AssertionError("Failure getting datafeed timing stats", e); });

         verify(client).prepareSearch(indexName);
         verify(client).threadPool();
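
The test tweaks follow the same theme: a bare new AssertionError() discards the triggering exception, while passing a message and cause keeps the original stack trace in the failure report. A minimal sketch:

class AssertionCauseDemo {
    public static void main(String[] args) {
        Exception root = new IllegalStateException("shard unavailable");
        try {
            // Before: throw new AssertionError();
            // The report would show only "java.lang.AssertionError", no context.

            // After: message and cause survive into the failure report.
            throw new AssertionError("Failure getting timing stats", root);
        } catch (AssertionError e) {
            e.printStackTrace(); // includes "Caused by: ... shard unavailable"
        }
    }
}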