[ML] Fix search that fetches results for renormalization (elastic/x-pack-elasticsearch#1556)

The commit that converted the results index to a single type broke
the search that fetches results for renormalization.
This commit fixes that search.

Original commit: elastic/x-pack-elasticsearch@1ca7517adc
Dimitris Athanasiou 2017-05-25 17:54:13 +01:00 committed by GitHub
parent 9b655ce6f1
commit 1e86f55746
11 changed files with 304 additions and 169 deletions
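In short: the batched results iterator used to restrict its scroll search to the result mapping type via SearchRequest.types(), and once the results index was converted to a single mapping type that restriction matched no documents, so renormalization found nothing to update. The fix drops the type restriction and expresses it as a term query on the result_type field instead, as the first file below shows. A minimal sketch of the corrected search construction, assuming the 5.x-era Java client API used in this codebase (the class and method names in the sketch are illustrative, not the real ones):

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;

public class RenormalizationSearchSketch {

    // Before the fix the request also carried searchRequest.types("result"),
    // which matches no documents on a single-type index. After the fix the
    // restriction lives in the query itself.
    public SearchRequest buildSearch(String index, String resultType, int batchSize) {
        SearchRequest searchRequest = new SearchRequest(index);
        QueryBuilder query = new TermsQueryBuilder("result_type", resultType);
        searchRequest.source(new SearchSourceBuilder()
                .size(batchSize)
                .query(query)
                .sort(SortBuilders.fieldSort("_doc")));
        return searchRequest;
    }
}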

View File: BatchedDocumentsIterator.java

@@ -15,7 +15,6 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.xpack.ml.job.results.Result;
import java.util.ArrayDeque;
import java.util.Arrays;
@@ -35,49 +34,17 @@ public abstract class BatchedDocumentsIterator<T> {
private final Client client;
private final String index;
private final ResultsFilterBuilder filterBuilder;
private volatile long count;
private volatile long totalHits;
private volatile String scrollId;
private volatile boolean isScrollInitialised;
public BatchedDocumentsIterator(Client client, String index) {
this(client, index, new ResultsFilterBuilder());
}
protected BatchedDocumentsIterator(Client client, String index, QueryBuilder queryBuilder) {
this(client, index, new ResultsFilterBuilder(queryBuilder));
}
private BatchedDocumentsIterator(Client client, String index, ResultsFilterBuilder resultsFilterBuilder) {
protected BatchedDocumentsIterator(Client client, String index) {
this.client = Objects.requireNonNull(client);
this.index = Objects.requireNonNull(index);
totalHits = 0;
count = 0;
filterBuilder = Objects.requireNonNull(resultsFilterBuilder);
isScrollInitialised = false;
}
/**
* Query documents whose timestamp is within the given time range
*
* @param startEpochMs the start time as epoch milliseconds (inclusive)
* @param endEpochMs the end time as epoch milliseconds (exclusive)
* @return the iterator itself
*/
public BatchedDocumentsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
filterBuilder.timeRange(Result.TIMESTAMP.getPreferredName(), startEpochMs, endEpochMs);
return this;
}
/**
* Sets whether interim results should be included
*
* @param includeInterim Whether interim results should be included
*/
public BatchedDocumentsIterator<T> includeInterim(boolean includeInterim) {
filterBuilder.interim(includeInterim);
return this;
this.totalHits = 0;
this.count = 0;
this.isScrollInitialised = false;
}
/**
@@ -118,17 +85,16 @@ public abstract class BatchedDocumentsIterator<T> {
}
private SearchResponse initScroll() {
LOGGER.trace("ES API CALL: search all of type {} from index {}", getType(), index);
LOGGER.trace("ES API CALL: search index {}", index);
isScrollInitialised = true;
SearchRequest searchRequest = new SearchRequest(index);
searchRequest.indicesOptions(JobProvider.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
searchRequest.types(getType());
searchRequest.scroll(CONTEXT_ALIVE_DURATION);
searchRequest.source(new SearchSourceBuilder()
.size(BATCH_SIZE)
.query(filterBuilder.build())
.query(getQuery())
.sort(SortBuilders.fieldSort(ElasticsearchMappings.ES_DOC)));
SearchResponse searchResponse = client.search(searchRequest).actionGet();
@@ -155,7 +121,11 @@ public abstract class BatchedDocumentsIterator<T> {
return results;
}
protected abstract String getType();
/**
* Get the query to use for the search
* @return the search query
*/
protected abstract QueryBuilder getQuery();
/**
* Maps the search hit to the document type
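With the filtering refactored away, getType() is gone and the only hooks a concrete iterator implements are getQuery() and map(). A hypothetical minimal subclass, mirroring the TestIterator in the unit tests further down:

import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;

// Illustrative subclass: documents are now selected purely by query,
// with no mapping-type restriction.
class AllDocsIterator extends BatchedDocumentsIterator<String> {

    AllDocsIterator(Client client, String index) {
        super(client, index);
    }

    @Override
    protected QueryBuilder getQuery() {
        return QueryBuilders.matchAllQuery();
    }

    @Override
    protected String map(SearchHit hit) {
        return hit.getSourceAsString();
    }
}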

View File: BatchedResultsIterator.java

@@ -6,18 +6,42 @@
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.xpack.ml.job.results.Result;
public abstract class BatchedResultsIterator<T> extends BatchedDocumentsIterator<Result<T>> {
private final ResultsFilterBuilder filterBuilder;
public BatchedResultsIterator(Client client, String jobId, String resultType) {
super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId),
new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId));
this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
}
@Override
protected String getType() {
return Result.TYPE.getPreferredName();
/**
* Query documents whose timestamp is within the given time range
*
* @param startEpochMs the start time as epoch milliseconds (inclusive)
* @param endEpochMs the end time as epoch milliseconds (exclusive)
* @return the iterator itself
*/
public BatchedResultsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
filterBuilder.timeRange(Result.TIMESTAMP.getPreferredName(), startEpochMs, endEpochMs);
return this;
}
/**
* Sets whether interim results should be included
*
* @param includeInterim Whether interim results should be included
*/
public BatchedResultsIterator<T> includeInterim(boolean includeInterim) {
filterBuilder.interim(includeInterim);
return this;
}
protected final QueryBuilder getQuery() {
return filterBuilder.build();
}
}
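The timeRange and includeInterim setters now feed the ResultsFilterBuilder held by this class, so the result_type term and the optional clauses are combined into a single filter. The query a records iterator ends up sending, with a time range set and interim results excluded, looks roughly like the sketch below; the field names follow the ML result mappings, but the exact clause composition inside ResultsFilterBuilder may differ:

import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;

class EffectiveQuerySketch {
    static QueryBuilder recordsQuery(long startEpochMs, long endEpochMs) {
        return QueryBuilders.boolQuery()
                .filter(QueryBuilders.termsQuery("result_type", "record"))
                // timeRange(): start is inclusive, end is exclusive
                .filter(QueryBuilders.rangeQuery("timestamp").gte(startEpochMs).lt(endEpochMs))
                // includeInterim(false)
                .filter(QueryBuilders.termQuery("is_interim", false));
    }
}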

View File: JobProvider.java

@@ -440,27 +440,26 @@ public class JobProvider {
}
/**
* Returns a {@link BatchedDocumentsIterator} that allows querying
* Returns a {@link BatchedResultsIterator} that allows querying
* and iterating over a large number of buckets of the given job.
* The bucket and source indexes are returned by the iterator.
*
* @param jobId the id of the job for which buckets are requested
* @return a bucket {@link BatchedDocumentsIterator}
* @return a bucket {@link BatchedResultsIterator}
*/
public BatchedDocumentsIterator<Result<Bucket>> newBatchedBucketsIterator(String jobId) {
public BatchedResultsIterator<Bucket> newBatchedBucketsIterator(String jobId) {
return new BatchedBucketsIterator(client, jobId);
}
/**
* Returns a {@link BatchedDocumentsIterator} that allows querying
* Returns a {@link BatchedResultsIterator} that allows querying
* and iterating over a large number of records in the given job
* The records and source indexes are returned by the iterator.
*
* @param jobId the id of the job for which records are requested
* @return a record {@link BatchedDocumentsIterator}
* @return a record {@link BatchedResultsIterator}
*/
public BatchedDocumentsIterator<Result<AnomalyRecord>>
newBatchedRecordsIterator(String jobId) {
public BatchedResultsIterator<AnomalyRecord> newBatchedRecordsIterator(String jobId) {
return new BatchedRecordsIterator(client, jobId);
}
@@ -672,13 +671,13 @@ public class JobProvider {
}
/**
* Returns a {@link BatchedDocumentsIterator} that allows querying
* Returns a {@link BatchedResultsIterator} that allows querying
* and iterating over a large number of influencers of the given job
*
* @param jobId the id of the job for which influencers are requested
* @return an influencer {@link BatchedDocumentsIterator}
* @return an influencer {@link BatchedResultsIterator}
*/
public BatchedDocumentsIterator<Result<Influencer>> newBatchedInfluencersIterator(String jobId) {
public BatchedResultsIterator<Influencer> newBatchedInfluencersIterator(String jobId) {
return new BatchedInfluencersIterator(client, jobId);
}
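Returning the narrower BatchedResultsIterator type keeps the fluent setters, which moved onto that class, available to callers such as the ScoresUpdater. A usage sketch (a hypothetical caller; the job provider and epoch bounds are assumed to be in scope):

import java.util.Deque;
import org.elasticsearch.xpack.ml.job.persistence.BatchedResultsIterator;
import org.elasticsearch.xpack.ml.job.persistence.JobProvider;
import org.elasticsearch.xpack.ml.job.results.Bucket;
import org.elasticsearch.xpack.ml.job.results.Result;

class BucketIterationSketch {
    void renormalizeBuckets(JobProvider jobProvider, String jobId, long startEpochMs, long endEpochMs) {
        BatchedResultsIterator<Bucket> iter = jobProvider.newBatchedBucketsIterator(jobId)
                .timeRange(startEpochMs, endEpochMs)
                .includeInterim(false);
        while (iter.hasNext()) {
            Deque<Result<Bucket>> batch = iter.next();
            // update the normalized scores of each bucket in the batch...
        }
    }
}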

View File: ExtractedFieldTests.java

@@ -5,16 +5,12 @@
*/
package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.joda.time.DateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
@@ -106,36 +102,4 @@ public class ExtractedFieldTests extends ESTestCase {
assertThat(timeField.value(hit), equalTo(new Object[] { 123456789L }));
}
static class SearchHitBuilder {
private final SearchHit hit;
private final Map<String, SearchHitField> fields;
SearchHitBuilder(int docId) {
hit = new SearchHit(docId);
fields = new HashMap<>();
}
SearchHitBuilder addField(String name, Object value) {
return addField(name, Arrays.asList(value));
}
SearchHitBuilder addField(String name, List<Object> values) {
fields.put(name, new SearchHitField(name, values));
return this;
}
SearchHitBuilder setSource(String sourceJson) {
hit.sourceRef(new BytesArray(sourceJson));
return this;
}
SearchHit build() {
if (!fields.isEmpty()) {
hit.fields(fields);
}
return hit;
}
}
}

View File: ExtractedFieldsTests.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.joda.time.DateTime;
import java.util.Arrays;
@@ -48,7 +49,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValue() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", new DateTime(1000L)).build();
SearchHit hit = new SearchHitBuilder(1).addField("time", new DateTime(1000L)).build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
@@ -56,7 +57,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValueGivenEmptyArray() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", Collections.emptyList()).build();
SearchHit hit = new SearchHitBuilder(1).addField("time", Collections.emptyList()).build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
@@ -64,7 +65,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValueGivenValueHasTwoElements() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", Arrays.asList(1L, 2L)).build();
SearchHit hit = new SearchHitBuilder(1).addField("time", Arrays.asList(1L, 2L)).build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
@@ -72,7 +73,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValueGivenValueIsString() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", "a string").build();
SearchHit hit = new SearchHitBuilder(1).addField("time", "a string").build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));

View File: SearchHitToJsonProcessorTests.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -24,7 +25,7 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField));
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(8)
SearchHit hit = new SearchHitBuilder(8)
.addField("time", 1000L)
.addField("single", "a")
.addField("array", Arrays.asList("b", "c"))
@@ -42,13 +43,13 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField));
SearchHit hit1 = new ExtractedFieldTests.SearchHitBuilder(8)
SearchHit hit1 = new SearchHitBuilder(8)
.addField("time", 1000L)
.addField("single", "a1")
.addField("array", Arrays.asList("b1", "c1"))
.build();
SearchHit hit2 = new ExtractedFieldTests.SearchHitBuilder(8)
SearchHit hit2 = new SearchHitBuilder(8)
.addField("time", 2000L)
.addField("single", "a2")
.addField("array", Arrays.asList("b2", "c2"))

View File: BasicRenormalizationIT.java (new file)

@@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.xpack.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.ml.job.config.DataDescription;
import org.elasticsearch.xpack.ml.job.config.Detector;
import org.elasticsearch.xpack.ml.job.config.Job;
import org.elasticsearch.xpack.ml.job.results.AnomalyRecord;
import org.junit.After;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
/**
* This is a minimal test to ensure renormalization takes place
*/
public class BasicRenormalizationIT extends MlNativeAutodetectIntegTestCase {
@After
public void tearDownData() throws Exception {
cleanUp();
}
public void test() throws Exception {
TimeValue bucketSpan = TimeValue.timeValueHours(1);
long startTime = 1491004800000L;
Job.Builder job = buildAndRegisterJob("basic-renormalization-it-job", bucketSpan);
openJob(job.getId());
postData(job.getId(), generateData(startTime, bucketSpan, 50,
bucketIndex -> {
if (bucketIndex == 35) {
// First anomaly is 10 events
return 10;
} else if (bucketIndex == 45) {
// Second anomaly is 100, should get the highest score and should bring the first score down
return 100;
} else {
return 1;
}
}).stream().collect(Collectors.joining()));
closeJob(job.getId());
List<AnomalyRecord> records = getRecords(job.getId());
assertThat(records.size(), equalTo(2));
AnomalyRecord laterRecord = records.get(0);
assertThat(laterRecord.getActual().get(0), equalTo(100.0));
AnomalyRecord earlierRecord = records.get(1);
assertThat(earlierRecord.getActual().get(0), equalTo(10.0));
assertThat(laterRecord.getRecordScore(), greaterThan(earlierRecord.getRecordScore()));
// This is the key assertion: if renormalization never happened then the record_score would
// be the same as the initial_record_score on the anomaly record that happened earlier
assertThat(earlierRecord.getInitialRecordScore(), greaterThan(earlierRecord.getRecordScore()));
}
private Job.Builder buildAndRegisterJob(String jobId, TimeValue bucketSpan) throws Exception {
Detector.Builder detector = new Detector.Builder("count", null);
AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(Arrays.asList(detector.build()));
analysisConfig.setBucketSpan(bucketSpan);
Job.Builder job = new Job.Builder(jobId);
job.setAnalysisConfig(analysisConfig);
DataDescription.Builder dataDescription = new DataDescription.Builder();
job.setDataDescription(dataDescription);
registerJob(job);
putJob(job);
return job;
}
private static List<String> generateData(long timestamp, TimeValue bucketSpan, int bucketCount,
Function<Integer, Integer> timeToCountFunction) throws IOException {
List<String> data = new ArrayList<>();
long now = timestamp;
for (int bucketIndex = 0; bucketIndex < bucketCount; bucketIndex++) {
for (int count = 0; count < timeToCountFunction.apply(bucketIndex); count++) {
Map<String, Object> record = new HashMap<>();
record.put("time", now);
data.add(createJsonRecord(record));
}
now += bucketSpan.getMillis();
}
return data;
}
}
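To read the key assertions: renormalization rewrites record_score in place but leaves initial_record_score at the value written when the record was first emitted, so the earlier anomaly, pushed down once the much larger one arrives, must end up with initial_record_score greater than record_score. Schematically, with purely illustrative numbers:

// Illustrative values only, not what the test asserts numerically:
// earlierRecord (actual 10):  initial_record_score = 90.0, record_score = 45.0
// laterRecord   (actual 100): record_score = 98.0
// => laterRecord.getRecordScore() > earlierRecord.getRecordScore()
// => earlierRecord.getInitialRecordScore() > earlierRecord.getRecordScore()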

View File: BatchedDocumentsIteratorTests.java

@@ -5,32 +5,36 @@
*/
package org.elasticsearch.xpack.ml.job.persistence;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.ClearScrollRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.NoSuchElementException;
import static org.mockito.Matchers.any;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/prelert-legacy/issues/127")
public class BatchedDocumentsIteratorTests extends ESTestCase {
private static final String INDEX_NAME = ".ml-anomalies-foo";
private static final String SCROLL_ID = "someScrollId";
@@ -39,6 +43,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
private TestIterator testIterator;
private ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
private ArgumentCaptor<SearchScrollRequest> searchScrollRequestCaptor = ArgumentCaptor.forClass(SearchScrollRequest.class);
@Before
public void setUpMocks() {
client = Mockito.mock(Client.class);
@@ -54,10 +61,12 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
assertTrue(testIterator.next().isEmpty());
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(0);
}
public void testCallingNextWhenHasNextIsFalseThrows() {
new ScrollResponsesMocker().addBatch("a", "b", "c").finishMock();
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
testIterator.next();
assertFalse(testIterator.hasNext());
@@ -65,55 +74,84 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
public void testQueryReturnsSingleBatch() {
new ScrollResponsesMocker().addBatch("a", "b", "c").finishMock();
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
assertTrue(testIterator.hasNext());
Deque<String> batch = testIterator.next();
assertEquals(3, batch.size());
assertTrue(batch.containsAll(Arrays.asList("a", "b", "c")));
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(0);
}
public void testQueryReturnsThreeBatches() {
new ScrollResponsesMocker()
.addBatch("a", "b", "c")
.addBatch("d", "e")
.addBatch("f")
.addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
.addBatch(createJsonDoc("d"), createJsonDoc("e"))
.addBatch(createJsonDoc("f"))
.finishMock();
assertTrue(testIterator.hasNext());
Deque<String> batch = testIterator.next();
assertEquals(3, batch.size());
assertTrue(batch.containsAll(Arrays.asList("a", "b", "c")));
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
batch = testIterator.next();
assertEquals(2, batch.size());
assertTrue(batch.containsAll(Arrays.asList("d", "e")));
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("d"), createJsonDoc("e"))));
batch = testIterator.next();
assertEquals(1, batch.size());
assertTrue(batch.containsAll(Arrays.asList("f")));
assertTrue(batch.containsAll(Collections.singletonList(createJsonDoc("f"))));
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(2);
}
private String createJsonDoc(String value) {
return "{\"foo\":\"" + value + "\"}";
}
private void givenClearScrollRequest() {
ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class);
when(client.prepareClearScroll()).thenReturn(requestBuilder);
when(requestBuilder.setScrollIds(Arrays.asList(SCROLL_ID))).thenReturn(requestBuilder);
when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
when(requestBuilder.get()).thenAnswer((invocation) -> {
wasScrollCleared = true;
return null;
});
}
private void assertSearchRequest() {
List<SearchRequest> searchRequests = searchRequestCaptor.getAllValues();
assertThat(searchRequests.size(), equalTo(1));
SearchRequest searchRequest = searchRequests.get(0);
assertThat(searchRequest.indices(), equalTo(new String[] {INDEX_NAME}));
assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
assertThat(searchRequest.types().length, equalTo(0));
assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
}
private void assertSearchScrollRequests(int expectedCount) {
List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor.getAllValues();
assertThat(searchScrollRequests.size(), equalTo(expectedCount));
for (SearchScrollRequest request : searchScrollRequests) {
assertThat(request.scrollId(), equalTo(SCROLL_ID));
assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
}
}
private class ScrollResponsesMocker {
private List<String[]> batches = new ArrayList<>();
private long totalHits = 0;
private List<SearchScrollRequestBuilder> nextRequestBuilders = new ArrayList<>();
private List<SearchResponse> responses = new ArrayList<>();
ScrollResponsesMocker addBatch(String... hits) {
totalHits += hits.length;
@@ -121,6 +159,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
return this;
}
@SuppressWarnings("unchecked")
void finishMock() {
if (batches.isEmpty()) {
givenInitialResponse();
@@ -130,39 +169,38 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
for (int i = 1; i < batches.size(); ++i) {
givenNextResponse(batches.get(i));
}
if (nextRequestBuilders.size() > 0) {
SearchScrollRequestBuilder first = nextRequestBuilders.get(0);
if (nextRequestBuilders.size() > 1) {
SearchScrollRequestBuilder[] rest = new SearchScrollRequestBuilder[batches.size() - 1];
for (int i = 1; i < nextRequestBuilders.size(); ++i) {
rest[i - 1] = nextRequestBuilders.get(i);
if (responses.size() > 0) {
ActionFuture<SearchResponse> first = wrapResponse(responses.get(0));
if (responses.size() > 1) {
List<ActionFuture> rest = new ArrayList<>();
for (int i = 1; i < responses.size(); ++i) {
rest.add(wrapResponse(responses.get(i)));
}
when(client.prepareSearchScroll(SCROLL_ID)).thenReturn(first, rest);
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(
first, rest.toArray(new ActionFuture[rest.size() - 1]));
} else {
when(client.prepareSearchScroll(SCROLL_ID)).thenReturn(first);
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(first);
}
}
}
private void givenInitialResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
SearchRequestBuilder requestBuilder = mock(SearchRequestBuilder.class);
when(client.prepareSearch(INDEX_NAME)).thenReturn(requestBuilder);
when(requestBuilder.setScroll("5m")).thenReturn(requestBuilder);
when(requestBuilder.setSize(10000)).thenReturn(requestBuilder);
when(requestBuilder.setTypes("String")).thenReturn(requestBuilder);
when(requestBuilder.setQuery(any(QueryBuilder.class))).thenReturn(requestBuilder);
when(requestBuilder.addSort(any(SortBuilder.class))).thenReturn(requestBuilder);
when(requestBuilder.get()).thenReturn(searchResponse);
ActionFuture<SearchResponse> future = wrapResponse(searchResponse);
when(future.actionGet()).thenReturn(searchResponse);
when(client.search(searchRequestCaptor.capture())).thenReturn(future);
}
@SuppressWarnings("unchecked")
private ActionFuture<SearchResponse> wrapResponse(SearchResponse searchResponse) {
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(searchResponse);
return future;
}
private void givenNextResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
SearchScrollRequestBuilder requestBuilder = mock(SearchScrollRequestBuilder.class);
when(requestBuilder.setScrollId(SCROLL_ID)).thenReturn(requestBuilder);
when(requestBuilder.setScroll("5m")).thenReturn(requestBuilder);
when(requestBuilder.get()).thenReturn(searchResponse);
nextRequestBuilders.add(requestBuilder);
responses.add(createSearchResponseWithHits(hits));
}
private SearchResponse createSearchResponseWithHits(String... hits) {
@@ -174,16 +212,11 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
private SearchHits createHits(String... values) {
SearchHits searchHits = mock(SearchHits.class);
List<SearchHit> hits = new ArrayList<>();
for (String value : values) {
SearchHit hit = mock(SearchHit.class);
when(hit.getSourceAsString()).thenReturn(value);
hits.add(hit);
hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
}
when(searchHits.getTotalHits()).thenReturn(totalHits);
when(searchHits.getHits()).thenReturn(hits.toArray(new SearchHit[hits.size()]));
return searchHits;
return new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, 1.0f);
}
}
@@ -193,8 +226,8 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
@Override
protected String getType() {
return "String";
protected QueryBuilder getQuery() {
return QueryBuilders.matchAllQuery();
}
@Override
@@ -202,5 +235,4 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
return hit.getSourceAsString();
}
}
}

View File: MockBatchedDocumentsIterator.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.ml.job.results.Result;
import java.util.Deque;
import java.util.List;
@@ -15,34 +16,35 @@ import java.util.NoSuchElementException;
import static org.mockito.Mockito.mock;
public class MockBatchedDocumentsIterator<T> extends BatchedDocumentsIterator<T> {
private final List<Deque<T>> batches;
public class MockBatchedDocumentsIterator<T> extends BatchedResultsIterator<T> {
private final List<Deque<Result<T>>> batches;
private int index;
private boolean wasTimeRangeCalled;
private Boolean includeInterim;
private Boolean requireIncludeInterim;
public MockBatchedDocumentsIterator(List<Deque<T>> batches) {
super(mock(Client.class), "foo");
public MockBatchedDocumentsIterator(List<Deque<Result<T>>> batches, String resultType) {
super(mock(Client.class), "foo", resultType);
this.batches = batches;
index = 0;
wasTimeRangeCalled = false;
}
@Override
public BatchedDocumentsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
public BatchedResultsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
wasTimeRangeCalled = true;
return this;
}
@Override
public BatchedDocumentsIterator<T> includeInterim(boolean includeInterim) {
public BatchedResultsIterator<T> includeInterim(boolean includeInterim) {
this.includeInterim = includeInterim;
return this;
}
@Override
public Deque<T> next() {
public Deque<Result<T>> next() {
if (requireIncludeInterim != null && requireIncludeInterim != includeInterim) {
throw new IllegalStateException("Required include interim value [" + requireIncludeInterim + "]; actual was ["
+ includeInterim + "]");
@@ -54,12 +56,7 @@ public class MockBatchedDocumentsIterator<T> extends BatchedDocumentsIterator<T>
}
@Override
protected String getType() {
return null;
}
@Override
protected T map(SearchHit hit) {
protected Result<T> map(SearchHit hit) {
return null;
}

View File: ScoresUpdaterTests.java

@@ -209,8 +209,8 @@ public class ScoresUpdaterTests extends ESTestCase {
List<Deque<Result<AnomalyRecord>>> recordBatches = new ArrayList<>();
recordBatches.add(new ArrayDeque<>(records));
MockBatchedDocumentsIterator<Result<AnomalyRecord>> recordIter =
new MockBatchedDocumentsIterator<>(recordBatches);
MockBatchedDocumentsIterator<AnomalyRecord> recordIter = new MockBatchedDocumentsIterator<>(
recordBatches, AnomalyRecord.RESULT_TYPE_VALUE);
recordIter.requireIncludeInterim(false);
when(jobProvider.newBatchedRecordsIterator(JOB_ID)).thenReturn(recordIter);
@@ -376,8 +376,7 @@ public class ScoresUpdaterTests extends ESTestCase {
batchesWithIndex.add(queueWithIndex);
}
MockBatchedDocumentsIterator<Result<Bucket>> bucketIter =
new MockBatchedDocumentsIterator<>(batchesWithIndex);
MockBatchedDocumentsIterator<Bucket> bucketIter = new MockBatchedDocumentsIterator<>(batchesWithIndex, Bucket.RESULT_TYPE_VALUE);
bucketIter.requireIncludeInterim(false);
when(jobProvider.newBatchedBucketsIterator(JOB_ID)).thenReturn(bucketIter);
}
@@ -394,8 +393,8 @@ public class ScoresUpdaterTests extends ESTestCase {
}
batches.add(batch);
MockBatchedDocumentsIterator<Result<AnomalyRecord>> recordIter =
new MockBatchedDocumentsIterator<>(batches);
MockBatchedDocumentsIterator<AnomalyRecord> recordIter = new MockBatchedDocumentsIterator<>(
batches, AnomalyRecord.RESULT_TYPE_VALUE);
recordIter.requireIncludeInterim(false);
when(jobProvider.newBatchedRecordsIterator(JOB_ID)).thenReturn(recordIter);
}
@@ -411,7 +410,7 @@ public class ScoresUpdaterTests extends ESTestCase {
queue.add(new Result<>("foo", inf));
}
batches.add(queue);
MockBatchedDocumentsIterator<Result<Influencer>> iterator = new MockBatchedDocumentsIterator<>(batches);
MockBatchedDocumentsIterator<Influencer> iterator = new MockBatchedDocumentsIterator<>(batches, Influencer.RESULT_TYPE_VALUE);
iterator.requireIncludeInterim(false);
when(jobProvider.newBatchedInfluencersIterator(JOB_ID)).thenReturn(iterator);
}

View File: SearchHitBuilder.java (new file)

@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.test;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Utility class to build {@link SearchHit} in tests
*/
public class SearchHitBuilder {
private final SearchHit hit;
private final Map<String, SearchHitField> fields;
public SearchHitBuilder(int docId) {
hit = new SearchHit(docId);
fields = new HashMap<>();
}
public SearchHitBuilder addField(String name, Object value) {
return addField(name, Arrays.asList(value));
}
public SearchHitBuilder addField(String name, List<Object> values) {
fields.put(name, new SearchHitField(name, values));
return this;
}
public SearchHitBuilder setSource(String sourceJson) {
hit.sourceRef(new BytesArray(sourceJson));
return this;
}
public SearchHit build() {
if (!fields.isEmpty()) {
hit.fields(fields);
}
return hit;
}
}
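A quick usage sketch of the extracted helper, matching how the datafeed extractor tests above use it:

import java.util.Arrays;

import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;

class SearchHitBuilderUsageSketch {
    static SearchHit exampleHit() {
        // A hit with two doc-value fields and a JSON _source:
        return new SearchHitBuilder(42)
                .addField("time", 1000L)
                .addField("array", Arrays.asList("b", "c"))
                .setSource("{\"foo\":\"bar\"}")
                .build();
    }
}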