[ML] Fix search that fetches results for renormalization ()

The commit that converted the results index to a single type
broke the search that fetches results for renormalization.
This commit fixes that.
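
For context, the essence of the change: the iterator no longer sets a mapping type on the search request (searchRequest.types(getType()) is removed) and instead selects the wanted results with a query on the result_type field. A minimal sketch of the fixed request construction, with illustrative names and constants rather than the actual code:

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;

class RenormalizationSearchSketch {
    // Builds the kind of scroll search the fixed iterator issues: no mapping
    // type is set; the result kind is chosen by a query on result_type.
    static SearchRequest buildSearch(String index, String resultType) {
        SearchRequest searchRequest = new SearchRequest(index);
        searchRequest.scroll(TimeValue.timeValueMinutes(5)); // CONTEXT_ALIVE_DURATION in the real class
        searchRequest.source(new SearchSourceBuilder()
                .size(10000) // BATCH_SIZE in the real class
                .query(new TermsQueryBuilder("result_type", resultType))
                .sort(SortBuilders.fieldSort("_doc")));
        return searchRequest;
    }
}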

Original commit: elastic/x-pack-elasticsearch@1ca7517adc
Dimitris Athanasiou 2017-05-25 17:54:13 +01:00 committed by GitHub
parent 9b655ce6f1
commit 1e86f55746
11 changed files with 304 additions and 169 deletions

@@ -15,7 +15,6 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.xpack.ml.job.results.Result;
import java.util.ArrayDeque;
import java.util.Arrays;
@@ -35,49 +34,17 @@ public abstract class BatchedDocumentsIterator<T> {
private final Client client;
private final String index;
private final ResultsFilterBuilder filterBuilder;
private volatile long count;
private volatile long totalHits;
private volatile String scrollId;
private volatile boolean isScrollInitialised;
public BatchedDocumentsIterator(Client client, String index) {
this(client, index, new ResultsFilterBuilder());
}
protected BatchedDocumentsIterator(Client client, String index, QueryBuilder queryBuilder) {
this(client, index, new ResultsFilterBuilder(queryBuilder));
}
private BatchedDocumentsIterator(Client client, String index, ResultsFilterBuilder resultsFilterBuilder) {
protected BatchedDocumentsIterator(Client client, String index) {
this.client = Objects.requireNonNull(client);
this.index = Objects.requireNonNull(index);
totalHits = 0;
count = 0;
filterBuilder = Objects.requireNonNull(resultsFilterBuilder);
isScrollInitialised = false;
}
/**
* Query documents whose timestamp is within the given time range
*
* @param startEpochMs the start time as epoch milliseconds (inclusive)
* @param endEpochMs the end time as epoch milliseconds (exclusive)
* @return the iterator itself
*/
public BatchedDocumentsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
filterBuilder.timeRange(Result.TIMESTAMP.getPreferredName(), startEpochMs, endEpochMs);
return this;
}
/**
* Sets whether interim results should be included
*
* @param includeInterim Whether interim results should be included
*/
public BatchedDocumentsIterator<T> includeInterim(boolean includeInterim) {
filterBuilder.interim(includeInterim);
return this;
this.totalHits = 0;
this.count = 0;
this.isScrollInitialised = false;
}
/**
@@ -118,17 +85,16 @@ public abstract class BatchedDocumentsIterator<T> {
}
private SearchResponse initScroll() {
LOGGER.trace("ES API CALL: search all of type {} from index {}", getType(), index);
LOGGER.trace("ES API CALL: search index {}", index);
isScrollInitialised = true;
SearchRequest searchRequest = new SearchRequest(index);
searchRequest.indicesOptions(JobProvider.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
searchRequest.types(getType());
searchRequest.scroll(CONTEXT_ALIVE_DURATION);
searchRequest.source(new SearchSourceBuilder()
.size(BATCH_SIZE)
.query(filterBuilder.build())
.query(getQuery())
.sort(SortBuilders.fieldSort(ElasticsearchMappings.ES_DOC)));
SearchResponse searchResponse = client.search(searchRequest).actionGet();
@@ -155,7 +121,11 @@ public abstract class BatchedDocumentsIterator<T> {
return results;
}
protected abstract String getType();
/**
* Get the query to use for the search
* @return the search query
*/
protected abstract QueryBuilder getQuery();
/**
* Maps the search hit to the document type

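To make the new contract concrete, here is a hypothetical minimal subclass (not part of the commit) showing what implementers now provide: a QueryBuilder where a mapping type used to be returned.

import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;

// Hypothetical subclass: iterates raw JSON documents matching a field-level filter.
class JsonDocsIterator extends BatchedDocumentsIterator<String> {
    JsonDocsIterator(Client client, String index) {
        super(client, index);
    }

    @Override
    protected QueryBuilder getQuery() {
        // Select documents by field value rather than by mapping type
        return QueryBuilders.termQuery("result_type", "bucket");
    }

    @Override
    protected String map(SearchHit hit) {
        return hit.getSourceAsString();
    }
}
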
@@ -6,18 +6,42 @@
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.xpack.ml.job.results.Result;
public abstract class BatchedResultsIterator<T> extends BatchedDocumentsIterator<Result<T>> {
private final ResultsFilterBuilder filterBuilder;
public BatchedResultsIterator(Client client, String jobId, String resultType) {
super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId),
new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId));
this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
}
@Override
protected String getType() {
return Result.TYPE.getPreferredName();
/**
* Query documents whose timestamp is within the given time range
*
* @param startEpochMs the start time as epoch milliseconds (inclusive)
* @param endEpochMs the end time as epoch milliseconds (exclusive)
* @return the iterator itself
*/
public BatchedResultsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
filterBuilder.timeRange(Result.TIMESTAMP.getPreferredName(), startEpochMs, endEpochMs);
return this;
}
/**
* Sets whether interim results should be included
*
* @param includeInterim Whether interim results should be included
*/
public BatchedResultsIterator<T> includeInterim(boolean includeInterim) {
filterBuilder.interim(includeInterim);
return this;
}
protected final QueryBuilder getQuery() {
return filterBuilder.build();
}
}
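
As a usage sketch (the method and identifiers below are illustrative; JobProvider supplies the concrete iterators further down): time-range and interim filtering are now chained on BatchedResultsIterator rather than on the generic base class.

import java.util.Deque;
import org.elasticsearch.xpack.ml.job.results.Bucket;
import org.elasticsearch.xpack.ml.job.results.Result;

class RenormalizationLoopSketch {
    // Hypothetical driver loop: iterate the non-interim bucket results of a job
    // in batches, scoped to a time range.
    static void renormalizeBuckets(JobProvider jobProvider, String jobId,
                                   long startEpochMs, long endEpochMs) {
        BatchedResultsIterator<Bucket> iter = jobProvider.newBatchedBucketsIterator(jobId)
                .timeRange(startEpochMs, endEpochMs)
                .includeInterim(false);
        while (iter.hasNext()) {
            Deque<Result<Bucket>> batch = iter.next();
            // re-score each Result<Bucket> in the batch...
        }
    }
}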

@@ -440,27 +440,26 @@ public class JobProvider {
}
/**
* Returns a {@link BatchedDocumentsIterator} that allows querying
* Returns a {@link BatchedResultsIterator} that allows querying
* and iterating over a large number of buckets of the given job.
* The bucket and source indexes are returned by the iterator.
*
* @param jobId the id of the job for which buckets are requested
* @return a bucket {@link BatchedDocumentsIterator}
* @return a bucket {@link BatchedResultsIterator}
*/
public BatchedDocumentsIterator<Result<Bucket>> newBatchedBucketsIterator(String jobId) {
public BatchedResultsIterator<Bucket> newBatchedBucketsIterator(String jobId) {
return new BatchedBucketsIterator(client, jobId);
}
/**
* Returns a {@link BatchedDocumentsIterator} that allows querying
* Returns a {@link BatchedResultsIterator} that allows querying
* and iterating over a large number of records in the given job.
* The records and source indexes are returned by the iterator.
*
* @param jobId the id of the job for which records are requested
* @return a record {@link BatchedDocumentsIterator}
* @return a record {@link BatchedResultsIterator}
*/
public BatchedDocumentsIterator<Result<AnomalyRecord>>
newBatchedRecordsIterator(String jobId) {
public BatchedResultsIterator<AnomalyRecord> newBatchedRecordsIterator(String jobId) {
return new BatchedRecordsIterator(client, jobId);
}
@@ -672,13 +671,13 @@ public class JobProvider {
}
/**
* Returns a {@link BatchedDocumentsIterator} that allows querying
* Returns a {@link BatchedResultsIterator} that allows querying
* and iterating over a large number of influencers of the given job.
*
* @param jobId the id of the job for which influencers are requested
* @return an influencer {@link BatchedDocumentsIterator}
* @return an influencer {@link BatchedResultsIterator}
*/
public BatchedDocumentsIterator<Result<Influencer>> newBatchedInfluencersIterator(String jobId) {
public BatchedResultsIterator<Influencer> newBatchedInfluencersIterator(String jobId) {
return new BatchedInfluencersIterator(client, jobId);
}

@@ -5,16 +5,12 @@
*/
package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.joda.time.DateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
@@ -106,36 +102,4 @@ public class ExtractedFieldTests extends ESTestCase {
assertThat(timeField.value(hit), equalTo(new Object[] { 123456789L }));
}
static class SearchHitBuilder {
private final SearchHit hit;
private final Map<String, SearchHitField> fields;
SearchHitBuilder(int docId) {
hit = new SearchHit(docId);
fields = new HashMap<>();
}
SearchHitBuilder addField(String name, Object value) {
return addField(name, Arrays.asList(value));
}
SearchHitBuilder addField(String name, List<Object> values) {
fields.put(name, new SearchHitField(name, values));
return this;
}
SearchHitBuilder setSource(String sourceJson) {
hit.sourceRef(new BytesArray(sourceJson));
return this;
}
SearchHit build() {
if (!fields.isEmpty()) {
hit.fields(fields);
}
return hit;
}
}
}

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.joda.time.DateTime;
import java.util.Arrays;
@@ -48,7 +49,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValue() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", new DateTime(1000L)).build();
SearchHit hit = new SearchHitBuilder(1).addField("time", new DateTime(1000L)).build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
@@ -56,7 +57,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValueGivenEmptyArray() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", Collections.emptyList()).build();
SearchHit hit = new SearchHitBuilder(1).addField("time", Collections.emptyList()).build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
@@ -64,7 +65,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValueGivenValueHasTwoElements() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", Arrays.asList(1L, 2L)).build();
SearchHit hit = new SearchHitBuilder(1).addField("time", Arrays.asList(1L, 2L)).build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
@@ -72,7 +73,7 @@ public class ExtractedFieldsTests extends ESTestCase {
}
public void testTimeFieldValueGivenValueIsString() {
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", "a string").build();
SearchHit hit = new SearchHitBuilder(1).addField("time", "a string").build();
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -24,7 +25,7 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField));
SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(8)
SearchHit hit = new SearchHitBuilder(8)
.addField("time", 1000L)
.addField("single", "a")
.addField("array", Arrays.asList("b", "c"))
@@ -42,13 +43,13 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField));
SearchHit hit1 = new ExtractedFieldTests.SearchHitBuilder(8)
SearchHit hit1 = new SearchHitBuilder(8)
.addField("time", 1000L)
.addField("single", "a1")
.addField("array", Arrays.asList("b1", "c1"))
.build();
SearchHit hit2 = new ExtractedFieldTests.SearchHitBuilder(8)
SearchHit hit2 = new SearchHitBuilder(8)
.addField("time", 2000L)
.addField("single", "a2")
.addField("array", Arrays.asList("b2", "c2"))

@@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.xpack.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.ml.job.config.DataDescription;
import org.elasticsearch.xpack.ml.job.config.Detector;
import org.elasticsearch.xpack.ml.job.config.Job;
import org.elasticsearch.xpack.ml.job.results.AnomalyRecord;
import org.junit.After;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
/**
* This is a minimal test to ensure renormalization takes place
*/
public class BasicRenormalizationIT extends MlNativeAutodetectIntegTestCase {
@After
public void tearDownData() throws Exception {
cleanUp();
}
public void test() throws Exception {
TimeValue bucketSpan = TimeValue.timeValueHours(1);
long startTime = 1491004800000L;
Job.Builder job = buildAndRegisterJob("basic-renormalization-it-job", bucketSpan);
openJob(job.getId());
postData(job.getId(), generateData(startTime, bucketSpan, 50,
bucketIndex -> {
if (bucketIndex == 35) {
// First anomaly is 10 events
return 10;
} else if (bucketIndex == 45) {
// Second anomaly is 100, should get the highest score and should bring the first score down
return 100;
} else {
return 1;
}
}).stream().collect(Collectors.joining()));
closeJob(job.getId());
List<AnomalyRecord> records = getRecords(job.getId());
assertThat(records.size(), equalTo(2));
AnomalyRecord laterRecord = records.get(0);
assertThat(laterRecord.getActual().get(0), equalTo(100.0));
AnomalyRecord earlierRecord = records.get(1);
assertThat(earlierRecord.getActual().get(0), equalTo(10.0));
assertThat(laterRecord.getRecordScore(), greaterThan(earlierRecord.getRecordScore()));
// This is the key assertion: if renormalization never happened then the record_score would
// be the same as the initial_record_score on the anomaly record that happened earlier
assertThat(earlierRecord.getInitialRecordScore(), greaterThan(earlierRecord.getRecordScore()));
}
private Job.Builder buildAndRegisterJob(String jobId, TimeValue bucketSpan) throws Exception {
Detector.Builder detector = new Detector.Builder("count", null);
AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(Arrays.asList(detector.build()));
analysisConfig.setBucketSpan(bucketSpan);
Job.Builder job = new Job.Builder(jobId);
job.setAnalysisConfig(analysisConfig);
DataDescription.Builder dataDescription = new DataDescription.Builder();
job.setDataDescription(dataDescription);
registerJob(job);
putJob(job);
return job;
}
private static List<String> generateData(long timestamp, TimeValue bucketSpan, int bucketCount,
Function<Integer, Integer> timeToCountFunction) throws IOException {
List<String> data = new ArrayList<>();
long now = timestamp;
for (int bucketIndex = 0; bucketIndex < bucketCount; bucketIndex++) {
for (int count = 0; count < timeToCountFunction.apply(bucketIndex); count++) {
Map<String, Object> record = new HashMap<>();
record.put("time", now);
data.add(createJsonRecord(record));
}
now += bucketSpan.getMillis();
}
return data;
}
}

@@ -5,32 +5,36 @@
*/
package org.elasticsearch.xpack.ml.job.persistence;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.ClearScrollRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.NoSuchElementException;
import static org.mockito.Matchers.any;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/prelert-legacy/issues/127")
public class BatchedDocumentsIteratorTests extends ESTestCase {
private static final String INDEX_NAME = ".ml-anomalies-foo";
private static final String SCROLL_ID = "someScrollId";
@@ -39,6 +43,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
private TestIterator testIterator;
private ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
private ArgumentCaptor<SearchScrollRequest> searchScrollRequestCaptor = ArgumentCaptor.forClass(SearchScrollRequest.class);
@Before
public void setUpMocks() {
client = Mockito.mock(Client.class);
@@ -54,10 +61,12 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
assertTrue(testIterator.next().isEmpty());
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(0);
}
public void testCallingNextWhenHasNextIsFalseThrows() {
new ScrollResponsesMocker().addBatch("a", "b", "c").finishMock();
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
testIterator.next();
assertFalse(testIterator.hasNext());
@@ -65,55 +74,84 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
public void testQueryReturnsSingleBatch() {
new ScrollResponsesMocker().addBatch("a", "b", "c").finishMock();
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
assertTrue(testIterator.hasNext());
Deque<String> batch = testIterator.next();
assertEquals(3, batch.size());
assertTrue(batch.containsAll(Arrays.asList("a", "b", "c")));
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(0);
}
public void testQueryReturnsThreeBatches() {
new ScrollResponsesMocker()
.addBatch("a", "b", "c")
.addBatch("d", "e")
.addBatch("f")
.addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
.addBatch(createJsonDoc("d"), createJsonDoc("e"))
.addBatch(createJsonDoc("f"))
.finishMock();
assertTrue(testIterator.hasNext());
Deque<String> batch = testIterator.next();
assertEquals(3, batch.size());
assertTrue(batch.containsAll(Arrays.asList("a", "b", "c")));
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
batch = testIterator.next();
assertEquals(2, batch.size());
assertTrue(batch.containsAll(Arrays.asList("d", "e")));
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("d"), createJsonDoc("e"))));
batch = testIterator.next();
assertEquals(1, batch.size());
assertTrue(batch.containsAll(Arrays.asList("f")));
assertTrue(batch.containsAll(Collections.singletonList(createJsonDoc("f"))));
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(2);
}
private String createJsonDoc(String value) {
return "{\"foo\":\"" + value + "\"}";
}
private void givenClearScrollRequest() {
ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class);
when(client.prepareClearScroll()).thenReturn(requestBuilder);
when(requestBuilder.setScrollIds(Arrays.asList(SCROLL_ID))).thenReturn(requestBuilder);
when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
when(requestBuilder.get()).thenAnswer((invocation) -> {
wasScrollCleared = true;
return null;
});
}
private void assertSearchRequest() {
List<SearchRequest> searchRequests = searchRequestCaptor.getAllValues();
assertThat(searchRequests.size(), equalTo(1));
SearchRequest searchRequest = searchRequests.get(0);
assertThat(searchRequest.indices(), equalTo(new String[] {INDEX_NAME}));
assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
assertThat(searchRequest.types().length, equalTo(0));
assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
}
private void assertSearchScrollRequests(int expectedCount) {
List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor.getAllValues();
assertThat(searchScrollRequests.size(), equalTo(expectedCount));
for (SearchScrollRequest request : searchScrollRequests) {
assertThat(request.scrollId(), equalTo(SCROLL_ID));
assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
}
}
private class ScrollResponsesMocker {
private List<String[]> batches = new ArrayList<>();
private long totalHits = 0;
private List<SearchScrollRequestBuilder> nextRequestBuilders = new ArrayList<>();
private List<SearchResponse> responses = new ArrayList<>();
ScrollResponsesMocker addBatch(String... hits) {
totalHits += hits.length;
@@ -121,6 +159,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
return this;
}
@SuppressWarnings("unchecked")
void finishMock() {
if (batches.isEmpty()) {
givenInitialResponse();
@@ -130,39 +169,38 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
for (int i = 1; i < batches.size(); ++i) {
givenNextResponse(batches.get(i));
}
if (nextRequestBuilders.size() > 0) {
SearchScrollRequestBuilder first = nextRequestBuilders.get(0);
if (nextRequestBuilders.size() > 1) {
SearchScrollRequestBuilder[] rest = new SearchScrollRequestBuilder[batches.size() - 1];
for (int i = 1; i < nextRequestBuilders.size(); ++i) {
rest[i - 1] = nextRequestBuilders.get(i);
if (responses.size() > 0) {
ActionFuture<SearchResponse> first = wrapResponse(responses.get(0));
if (responses.size() > 1) {
List<ActionFuture> rest = new ArrayList<>();
for (int i = 1; i < responses.size(); ++i) {
rest.add(wrapResponse(responses.get(i)));
}
when(client.prepareSearchScroll(SCROLL_ID)).thenReturn(first, rest);
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(
first, rest.toArray(new ActionFuture[rest.size() - 1]));
} else {
when(client.prepareSearchScroll(SCROLL_ID)).thenReturn(first);
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(first);
}
}
}
private void givenInitialResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
SearchRequestBuilder requestBuilder = mock(SearchRequestBuilder.class);
when(client.prepareSearch(INDEX_NAME)).thenReturn(requestBuilder);
when(requestBuilder.setScroll("5m")).thenReturn(requestBuilder);
when(requestBuilder.setSize(10000)).thenReturn(requestBuilder);
when(requestBuilder.setTypes("String")).thenReturn(requestBuilder);
when(requestBuilder.setQuery(any(QueryBuilder.class))).thenReturn(requestBuilder);
when(requestBuilder.addSort(any(SortBuilder.class))).thenReturn(requestBuilder);
when(requestBuilder.get()).thenReturn(searchResponse);
ActionFuture<SearchResponse> future = wrapResponse(searchResponse);
when(future.actionGet()).thenReturn(searchResponse);
when(client.search(searchRequestCaptor.capture())).thenReturn(future);
}
@SuppressWarnings("unchecked")
private ActionFuture<SearchResponse> wrapResponse(SearchResponse searchResponse) {
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(searchResponse);
return future;
}
private void givenNextResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
SearchScrollRequestBuilder requestBuilder = mock(SearchScrollRequestBuilder.class);
when(requestBuilder.setScrollId(SCROLL_ID)).thenReturn(requestBuilder);
when(requestBuilder.setScroll("5m")).thenReturn(requestBuilder);
when(requestBuilder.get()).thenReturn(searchResponse);
nextRequestBuilders.add(requestBuilder);
responses.add(createSearchResponseWithHits(hits));
}
private SearchResponse createSearchResponseWithHits(String... hits) {
@@ -174,16 +212,11 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
private SearchHits createHits(String... values) {
SearchHits searchHits = mock(SearchHits.class);
List<SearchHit> hits = new ArrayList<>();
for (String value : values) {
SearchHit hit = mock(SearchHit.class);
when(hit.getSourceAsString()).thenReturn(value);
hits.add(hit);
hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
}
when(searchHits.getTotalHits()).thenReturn(totalHits);
when(searchHits.getHits()).thenReturn(hits.toArray(new SearchHit[hits.size()]));
return searchHits;
return new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, 1.0f);
}
}
@@ -193,8 +226,8 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
@Override
protected String getType() {
return "String";
protected QueryBuilder getQuery() {
return QueryBuilders.matchAllQuery();
}
@Override
@@ -202,5 +235,4 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
return hit.getSourceAsString();
}
}
}

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.ml.job.results.Result;
import java.util.Deque;
import java.util.List;
@@ -15,34 +16,35 @@ import java.util.NoSuchElementException;
import static org.mockito.Mockito.mock;
public class MockBatchedDocumentsIterator<T> extends BatchedDocumentsIterator<T> {
private final List<Deque<T>> batches;
public class MockBatchedDocumentsIterator<T> extends BatchedResultsIterator<T> {
private final List<Deque<Result<T>>> batches;
private int index;
private boolean wasTimeRangeCalled;
private Boolean includeInterim;
private Boolean requireIncludeInterim;
public MockBatchedDocumentsIterator(List<Deque<T>> batches) {
super(mock(Client.class), "foo");
public MockBatchedDocumentsIterator(List<Deque<Result<T>>> batches, String resultType) {
super(mock(Client.class), "foo", resultType);
this.batches = batches;
index = 0;
wasTimeRangeCalled = false;
}
@Override
public BatchedDocumentsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
public BatchedResultsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
wasTimeRangeCalled = true;
return this;
}
@Override
public BatchedDocumentsIterator<T> includeInterim(boolean includeInterim) {
public BatchedResultsIterator<T> includeInterim(boolean includeInterim) {
this.includeInterim = includeInterim;
return this;
}
@Override
public Deque<T> next() {
public Deque<Result<T>> next() {
if (requireIncludeInterim != null && requireIncludeInterim != includeInterim) {
throw new IllegalStateException("Required include interim value [" + requireIncludeInterim + "]; actual was ["
+ includeInterim + "]");
@@ -54,12 +56,7 @@ public class MockBatchedDocumentsIterator<T> extends BatchedDocumentsIterator<T>
}
@Override
protected String getType() {
return null;
}
@Override
protected T map(SearchHit hit) {
protected Result<T> map(SearchHit hit) {
return null;
}

@@ -209,8 +209,8 @@ public class ScoresUpdaterTests extends ESTestCase {
List<Deque<Result<AnomalyRecord>>> recordBatches = new ArrayList<>();
recordBatches.add(new ArrayDeque<>(records));
MockBatchedDocumentsIterator<Result<AnomalyRecord>> recordIter =
new MockBatchedDocumentsIterator<>(recordBatches);
MockBatchedDocumentsIterator<AnomalyRecord> recordIter = new MockBatchedDocumentsIterator<>(
recordBatches, AnomalyRecord.RESULT_TYPE_VALUE);
recordIter.requireIncludeInterim(false);
when(jobProvider.newBatchedRecordsIterator(JOB_ID)).thenReturn(recordIter);
@@ -376,8 +376,7 @@ public class ScoresUpdaterTests extends ESTestCase {
batchesWithIndex.add(queueWithIndex);
}
MockBatchedDocumentsIterator<Result<Bucket>> bucketIter =
new MockBatchedDocumentsIterator<>(batchesWithIndex);
MockBatchedDocumentsIterator<Bucket> bucketIter = new MockBatchedDocumentsIterator<>(batchesWithIndex, Bucket.RESULT_TYPE_VALUE);
bucketIter.requireIncludeInterim(false);
when(jobProvider.newBatchedBucketsIterator(JOB_ID)).thenReturn(bucketIter);
}
@@ -394,8 +393,8 @@ public class ScoresUpdaterTests extends ESTestCase {
}
batches.add(batch);
MockBatchedDocumentsIterator<Result<AnomalyRecord>> recordIter =
new MockBatchedDocumentsIterator<>(batches);
MockBatchedDocumentsIterator<AnomalyRecord> recordIter = new MockBatchedDocumentsIterator<>(
batches, AnomalyRecord.RESULT_TYPE_VALUE);
recordIter.requireIncludeInterim(false);
when(jobProvider.newBatchedRecordsIterator(JOB_ID)).thenReturn(recordIter);
}
@@ -411,7 +410,7 @@ public class ScoresUpdaterTests extends ESTestCase {
queue.add(new Result<>("foo", inf));
}
batches.add(queue);
MockBatchedDocumentsIterator<Result<Influencer>> iterator = new MockBatchedDocumentsIterator<>(batches);
MockBatchedDocumentsIterator<Influencer> iterator = new MockBatchedDocumentsIterator<>(batches, Influencer.RESULT_TYPE_VALUE);
iterator.requireIncludeInterim(false);
when(jobProvider.newBatchedInfluencersIterator(JOB_ID)).thenReturn(iterator);
}

@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.test;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Utility class to build {@link SearchHit} in tests
*/
public class SearchHitBuilder {
private final SearchHit hit;
private final Map<String, SearchHitField> fields;
public SearchHitBuilder(int docId) {
hit = new SearchHit(docId);
fields = new HashMap<>();
}
public SearchHitBuilder addField(String name, Object value) {
return addField(name, Arrays.asList(value));
}
public SearchHitBuilder addField(String name, List<Object> values) {
fields.put(name, new SearchHitField(name, values));
return this;
}
public SearchHitBuilder setSource(String sourceJson) {
hit.sourceRef(new BytesArray(sourceJson));
return this;
}
public SearchHit build() {
if (!fields.isEmpty()) {
hit.fields(fields);
}
return hit;
}
}