[ML] Fix search that fetches results for renormalization (elastic/x-pack-elasticsearch#1556)

The commit that converted the results index to a single type broke the search that fetches results for renormalization. This commit fixes that. Original commit: elastic/x-pack-elasticsearch@1ca7517adc
commit 1e86f55746 (parent 9b655ce6f1)
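The essence of the change below: BatchedDocumentsIterator stops addressing a mapping type via getType() and instead asks subclasses for the search query through an abstract getQuery(), while BatchedResultsIterator keeps the result_type terms filter plus the timeRange/includeInterim builders. A minimal, dependency-free Java sketch of that template-method shape follows; the class and method names (DocsIterator, ResultsIterator, describeSearch, QuerySketch) are hypothetical abbreviations, the query is rendered as a plain string, and this is an illustration of the pattern, not the x-pack source:

import java.util.Objects;

// Base iterator: knows the index and the scroll mechanics, but delegates
// the query to subclasses (this replaces the old getType() hook).
abstract class DocsIterator {
    private final String index;

    protected DocsIterator(String index) {
        this.index = Objects.requireNonNull(index);
    }

    // Subclasses decide what the search matches.
    protected abstract String getQuery();

    String describeSearch() {
        // Single-type results index: no type clause, just index + query.
        return "search index=" + index + " query=" + getQuery();
    }
}

// Results iterator: restricts the search to one result type, mirroring the
// ResultsFilterBuilder(new TermsQueryBuilder("result_type", ...)) in the diff.
class ResultsIterator extends DocsIterator {
    private final String resultType;

    ResultsIterator(String jobId, String resultType) {
        super(".ml-anomalies-" + jobId);   // illustrative alias name
        this.resultType = resultType;
    }

    @Override
    protected String getQuery() {
        return "{\"terms\":{\"result_type\":[\"" + resultType + "\"]}}";
    }
}

public class QuerySketch {
    public static void main(String[] args) {
        // Prints the search the renormalizer's record iterator would issue.
        System.out.println(new ResultsIterator("foo", "record").describeSearch());
    }
}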
BatchedDocumentsIterator.java

@@ -15,7 +15,6 @@ import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.sort.SortBuilders;
-import org.elasticsearch.xpack.ml.job.results.Result;
 
 import java.util.ArrayDeque;
 import java.util.Arrays;
@@ -35,49 +34,17 @@ public abstract class BatchedDocumentsIterator<T> {
 
     private final Client client;
     private final String index;
-    private final ResultsFilterBuilder filterBuilder;
     private volatile long count;
     private volatile long totalHits;
     private volatile String scrollId;
    private volatile boolean isScrollInitialised;
 
-    public BatchedDocumentsIterator(Client client, String index) {
-        this(client, index, new ResultsFilterBuilder());
-    }
-
-    protected BatchedDocumentsIterator(Client client, String index, QueryBuilder queryBuilder) {
-        this(client, index, new ResultsFilterBuilder(queryBuilder));
-    }
-
-    private BatchedDocumentsIterator(Client client, String index, ResultsFilterBuilder resultsFilterBuilder) {
+    protected BatchedDocumentsIterator(Client client, String index) {
         this.client = Objects.requireNonNull(client);
        this.index = Objects.requireNonNull(index);
-        totalHits = 0;
-        count = 0;
-        filterBuilder = Objects.requireNonNull(resultsFilterBuilder);
-        isScrollInitialised = false;
-    }
-
-    /**
-     * Query documents whose timestamp is within the given time range
-     *
-     * @param startEpochMs the start time as epoch milliseconds (inclusive)
-     * @param endEpochMs the end time as epoch milliseconds (exclusive)
-     * @return the iterator itself
-     */
-    public BatchedDocumentsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
-        filterBuilder.timeRange(Result.TIMESTAMP.getPreferredName(), startEpochMs, endEpochMs);
-        return this;
-    }
-
-    /**
-     * Sets whether interim results should be included
-     *
-     * @param includeInterim Whether interim results should be included
-     */
-    public BatchedDocumentsIterator<T> includeInterim(boolean includeInterim) {
-        filterBuilder.interim(includeInterim);
-        return this;
+        this.totalHits = 0;
+        this.count = 0;
+        this.isScrollInitialised = false;
     }
 
     /**
@@ -118,17 +85,16 @@ public abstract class BatchedDocumentsIterator<T> {
     }
 
     private SearchResponse initScroll() {
-        LOGGER.trace("ES API CALL: search all of type {} from index {}", getType(), index);
+        LOGGER.trace("ES API CALL: search index {}", index);
 
         isScrollInitialised = true;
 
         SearchRequest searchRequest = new SearchRequest(index);
         searchRequest.indicesOptions(JobProvider.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
-        searchRequest.types(getType());
         searchRequest.scroll(CONTEXT_ALIVE_DURATION);
         searchRequest.source(new SearchSourceBuilder()
                 .size(BATCH_SIZE)
-                .query(filterBuilder.build())
+                .query(getQuery())
                 .sort(SortBuilders.fieldSort(ElasticsearchMappings.ES_DOC)));
 
         SearchResponse searchResponse = client.search(searchRequest).actionGet();
@@ -155,7 +121,11 @@ public abstract class BatchedDocumentsIterator<T> {
         return results;
     }
 
-    protected abstract String getType();
+    /**
+     * Get the query to use for the search
+     * @return the search query
+     */
+    protected abstract QueryBuilder getQuery();
 
     /**
     * Maps the search hit to the document type
BatchedResultsIterator.java

@@ -6,18 +6,42 @@
 package org.elasticsearch.xpack.ml.job.persistence;
 
 import org.elasticsearch.client.Client;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.TermsQueryBuilder;
 import org.elasticsearch.xpack.ml.job.results.Result;
 
 public abstract class BatchedResultsIterator<T> extends BatchedDocumentsIterator<Result<T>> {
 
+    private final ResultsFilterBuilder filterBuilder;
+
     public BatchedResultsIterator(Client client, String jobId, String resultType) {
-        super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId),
-                new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
+        super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId));
+        this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
     }
 
-    @Override
-    protected String getType() {
-        return Result.TYPE.getPreferredName();
+    /**
+     * Query documents whose timestamp is within the given time range
+     *
+     * @param startEpochMs the start time as epoch milliseconds (inclusive)
+     * @param endEpochMs the end time as epoch milliseconds (exclusive)
+     * @return the iterator itself
+     */
+    public BatchedResultsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
+        filterBuilder.timeRange(Result.TIMESTAMP.getPreferredName(), startEpochMs, endEpochMs);
+        return this;
+    }
+
+    /**
+     * Sets whether interim results should be included
+     *
+     * @param includeInterim Whether interim results should be included
+     */
+    public BatchedResultsIterator<T> includeInterim(boolean includeInterim) {
+        filterBuilder.interim(includeInterim);
+        return this;
+    }
+
+    protected final QueryBuilder getQuery() {
+        return filterBuilder.build();
     }
 }
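After this change a caller composes the time range and interim filter on the results iterator itself rather than on the base class. A rough, hypothetical call-site fragment (imports omitted; the jobProvider variable, job id, and epoch bounds are illustrative placeholders, not lines from the diff):

// Hypothetical fragment; assumes a JobProvider named jobProvider is in scope.
BatchedResultsIterator<Bucket> buckets = jobProvider.newBatchedBucketsIterator("my-job")
        .timeRange(1000000000000L, 2000000000000L)   // [start, end) in epoch millis
        .includeInterim(false);                       // exclude interim results
while (buckets.hasNext()) {
    Deque<Result<Bucket>> batch = buckets.next();     // one scroll page of results
    // renormalize scores for this batch...
}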
JobProvider.java

@@ -440,27 +440,26 @@ public class JobProvider {
     }
 
     /**
-     * Returns a {@link BatchedDocumentsIterator} that allows querying
+     * Returns a {@link BatchedResultsIterator} that allows querying
      * and iterating over a large number of buckets of the given job.
      * The bucket and source indexes are returned by the iterator.
      *
      * @param jobId the id of the job for which buckets are requested
-     * @return a bucket {@link BatchedDocumentsIterator}
+     * @return a bucket {@link BatchedResultsIterator}
      */
-    public BatchedDocumentsIterator<Result<Bucket>> newBatchedBucketsIterator(String jobId) {
+    public BatchedResultsIterator<Bucket> newBatchedBucketsIterator(String jobId) {
         return new BatchedBucketsIterator(client, jobId);
     }
 
     /**
-     * Returns a {@link BatchedDocumentsIterator} that allows querying
+     * Returns a {@link BatchedResultsIterator} that allows querying
      * and iterating over a large number of records in the given job
      * The records and source indexes are returned by the iterator.
      *
      * @param jobId the id of the job for which buckets are requested
-     * @return a record {@link BatchedDocumentsIterator}
+     * @return a record {@link BatchedResultsIterator}
      */
-    public BatchedDocumentsIterator<Result<AnomalyRecord>>
-            newBatchedRecordsIterator(String jobId) {
+    public BatchedResultsIterator<AnomalyRecord> newBatchedRecordsIterator(String jobId) {
         return new BatchedRecordsIterator(client, jobId);
     }
 
@@ -672,13 +671,13 @@ public class JobProvider {
     }
 
     /**
-     * Returns a {@link BatchedDocumentsIterator} that allows querying
+     * Returns a {@link BatchedResultsIterator} that allows querying
      * and iterating over a large number of influencers of the given job
      *
      * @param jobId the id of the job for which influencers are requested
-     * @return an influencer {@link BatchedDocumentsIterator}
+     * @return an influencer {@link BatchedResultsIterator}
     */
-    public BatchedDocumentsIterator<Result<Influencer>> newBatchedInfluencersIterator(String jobId) {
+    public BatchedResultsIterator<Influencer> newBatchedInfluencersIterator(String jobId) {
         return new BatchedInfluencersIterator(client, jobId);
     }
 
ExtractedFieldTests.java

@@ -5,16 +5,12 @@
  */
 package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
 
-import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.search.SearchHit;
-import org.elasticsearch.search.SearchHitField;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 import org.joda.time.DateTime;
 
 import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 
 import static org.hamcrest.Matchers.equalTo;
 
@@ -106,36 +102,4 @@ public class ExtractedFieldTests extends ESTestCase {
 
         assertThat(timeField.value(hit), equalTo(new Object[] { 123456789L }));
     }
-
-    static class SearchHitBuilder {
-
-        private final SearchHit hit;
-        private final Map<String, SearchHitField> fields;
-
-        SearchHitBuilder(int docId) {
-            hit = new SearchHit(docId);
-            fields = new HashMap<>();
-        }
-
-        SearchHitBuilder addField(String name, Object value) {
-            return addField(name, Arrays.asList(value));
-        }
-
-        SearchHitBuilder addField(String name, List<Object> values) {
-            fields.put(name, new SearchHitField(name, values));
-            return this;
-        }
-
-        SearchHitBuilder setSource(String sourceJson) {
-            hit.sourceRef(new BytesArray(sourceJson));
-            return this;
-        }
-
-        SearchHit build() {
-            if (!fields.isEmpty()) {
-                hit.fields(fields);
-            }
-            return hit;
-        }
-    }
 }
ExtractedFieldsTests.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
 
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 import org.joda.time.DateTime;
 
 import java.util.Arrays;
@@ -48,7 +49,7 @@ public class ExtractedFieldsTests extends ESTestCase {
     }
 
     public void testTimeFieldValue() {
-        SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", new DateTime(1000L)).build();
+        SearchHit hit = new SearchHitBuilder(1).addField("time", new DateTime(1000L)).build();
 
         ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
 
@@ -56,7 +57,7 @@ public class ExtractedFieldsTests extends ESTestCase {
     }
 
     public void testTimeFieldValueGivenEmptyArray() {
-        SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", Collections.emptyList()).build();
+        SearchHit hit = new SearchHitBuilder(1).addField("time", Collections.emptyList()).build();
 
         ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
 
@@ -64,7 +65,7 @@ public class ExtractedFieldsTests extends ESTestCase {
     }
 
     public void testTimeFieldValueGivenValueHasTwoElements() {
-        SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", Arrays.asList(1L, 2L)).build();
+        SearchHit hit = new SearchHitBuilder(1).addField("time", Arrays.asList(1L, 2L)).build();
 
         ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
 
@@ -72,7 +73,7 @@ public class ExtractedFieldsTests extends ESTestCase {
     }
 
     public void testTimeFieldValueGivenValueIsString() {
-        SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(1).addField("time", "a string").build();
+        SearchHit hit = new SearchHitBuilder(1).addField("time", "a string").build();
 
         ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField));
 
SearchHitToJsonProcessorTests.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
 
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
@@ -24,7 +25,7 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
         ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
         ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField));
 
-        SearchHit hit = new ExtractedFieldTests.SearchHitBuilder(8)
+        SearchHit hit = new SearchHitBuilder(8)
                 .addField("time", 1000L)
                 .addField("single", "a")
                 .addField("array", Arrays.asList("b", "c"))
@@ -42,13 +43,13 @@ public class SearchHitToJsonProcessorTests extends ESTestCase {
         ExtractedField arrayField = ExtractedField.newField("array", ExtractedField.ExtractionMethod.DOC_VALUE);
         ExtractedFields extractedFields = new ExtractedFields(timeField, Arrays.asList(timeField, missingField, singleField, arrayField));
 
-        SearchHit hit1 = new ExtractedFieldTests.SearchHitBuilder(8)
+        SearchHit hit1 = new SearchHitBuilder(8)
                 .addField("time", 1000L)
                 .addField("single", "a1")
                 .addField("array", Arrays.asList("b1", "c1"))
                 .build();
 
-        SearchHit hit2 = new ExtractedFieldTests.SearchHitBuilder(8)
+        SearchHit hit2 = new SearchHitBuilder(8)
                 .addField("time", 2000L)
                 .addField("single", "a2")
                 .addField("array", Arrays.asList("b2", "c2"))
BasicRenormalizationIT.java (new file)

@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.xpack.ml.job.config.AnalysisConfig;
+import org.elasticsearch.xpack.ml.job.config.DataDescription;
+import org.elasticsearch.xpack.ml.job.config.Detector;
+import org.elasticsearch.xpack.ml.job.config.Job;
+import org.elasticsearch.xpack.ml.job.results.AnomalyRecord;
+import org.junit.After;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+
+/**
+ * This is a minimal test to ensure renormalization takes place
+ */
+public class BasicRenormalizationIT extends MlNativeAutodetectIntegTestCase {
+
+    @After
+    public void tearDownData() throws Exception {
+        cleanUp();
+    }
+
+    public void test() throws Exception {
+        TimeValue bucketSpan = TimeValue.timeValueHours(1);
+        long startTime = 1491004800000L;
+
+        Job.Builder job = buildAndRegisterJob("basic-renormalization-it-job", bucketSpan);
+        openJob(job.getId());
+        postData(job.getId(), generateData(startTime, bucketSpan, 50,
+                bucketIndex -> {
+                    if (bucketIndex == 35) {
+                        // First anomaly is 10 events
+                        return 10;
+                    } else if (bucketIndex == 45) {
+                        // Second anomaly is 100, should get the highest score and should bring the first score down
+                        return 100;
+                    } else {
+                        return 1;
+                    }
+                }).stream().collect(Collectors.joining()));
+        closeJob(job.getId());
+
+        List<AnomalyRecord> records = getRecords(job.getId());
+        assertThat(records.size(), equalTo(2));
+        AnomalyRecord laterRecord = records.get(0);
+        assertThat(laterRecord.getActual().get(0), equalTo(100.0));
+        AnomalyRecord earlierRecord = records.get(1);
+        assertThat(earlierRecord.getActual().get(0), equalTo(10.0));
+        assertThat(laterRecord.getRecordScore(), greaterThan(earlierRecord.getRecordScore()));
+
+        // This is the key assertion: if renormalization never happened then the record_score would
+        // be the same as the initial_record_score on the anomaly record that happened earlier
+        assertThat(earlierRecord.getInitialRecordScore(), greaterThan(earlierRecord.getRecordScore()));
+    }
+
+    private Job.Builder buildAndRegisterJob(String jobId, TimeValue bucketSpan) throws Exception {
+        Detector.Builder detector = new Detector.Builder("count", null);
+        AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(Arrays.asList(detector.build()));
+        analysisConfig.setBucketSpan(bucketSpan);
+        Job.Builder job = new Job.Builder(jobId);
+        job.setAnalysisConfig(analysisConfig);
+        DataDescription.Builder dataDescription = new DataDescription.Builder();
+        job.setDataDescription(dataDescription);
+        registerJob(job);
+        putJob(job);
+        return job;
+    }
+
+    private static List<String> generateData(long timestamp, TimeValue bucketSpan, int bucketCount,
+                                             Function<Integer, Integer> timeToCountFunction) throws IOException {
+        List<String> data = new ArrayList<>();
+        long now = timestamp;
+        for (int bucketIndex = 0; bucketIndex < bucketCount; bucketIndex++) {
+            for (int count = 0; count < timeToCountFunction.apply(bucketIndex); count++) {
+                Map<String, Object> record = new HashMap<>();
+                record.put("time", now);
+                data.add(createJsonRecord(record));
+            }
+            now += bucketSpan.getMillis();
+        }
+        return data;
+    }
+}
BatchedDocumentsIteratorTests.java

@@ -5,32 +5,36 @@
  */
 package org.elasticsearch.xpack.ml.job.persistence;
 
 import org.apache.lucene.util.LuceneTestCase;
+import org.elasticsearch.action.ActionFuture;
 import org.elasticsearch.action.search.ClearScrollRequestBuilder;
-import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
-import org.elasticsearch.action.search.SearchScrollRequestBuilder;
+import org.elasticsearch.action.search.SearchScrollRequest;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
-import org.elasticsearch.search.sort.SortBuilder;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 import org.junit.Before;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mockito;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Deque;
 import java.util.List;
 import java.util.NoSuchElementException;
 
-import static org.mockito.Matchers.any;
+import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
-@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/prelert-legacy/issues/127")
 public class BatchedDocumentsIteratorTests extends ESTestCase {
 
     private static final String INDEX_NAME = ".ml-anomalies-foo";
     private static final String SCROLL_ID = "someScrollId";
@@ -39,6 +43,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
 
     private TestIterator testIterator;
 
+    private ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
+    private ArgumentCaptor<SearchScrollRequest> searchScrollRequestCaptor = ArgumentCaptor.forClass(SearchScrollRequest.class);
+
     @Before
     public void setUpMocks() {
         client = Mockito.mock(Client.class);
@@ -54,10 +61,12 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         assertTrue(testIterator.next().isEmpty());
         assertFalse(testIterator.hasNext());
         assertTrue(wasScrollCleared);
+        assertSearchRequest();
+        assertSearchScrollRequests(0);
     }
 
     public void testCallingNextWhenHasNextIsFalseThrows() {
-        new ScrollResponsesMocker().addBatch("a", "b", "c").finishMock();
+        new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
         testIterator.next();
         assertFalse(testIterator.hasNext());
 
@@ -65,55 +74,84 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
     }
 
     public void testQueryReturnsSingleBatch() {
-        new ScrollResponsesMocker().addBatch("a", "b", "c").finishMock();
+        new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
 
         assertTrue(testIterator.hasNext());
         Deque<String> batch = testIterator.next();
         assertEquals(3, batch.size());
-        assertTrue(batch.containsAll(Arrays.asList("a", "b", "c")));
+        assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
         assertFalse(testIterator.hasNext());
         assertTrue(wasScrollCleared);
+
+        assertSearchRequest();
+        assertSearchScrollRequests(0);
     }
 
     public void testQueryReturnsThreeBatches() {
         new ScrollResponsesMocker()
-                .addBatch("a", "b", "c")
-                .addBatch("d", "e")
-                .addBatch("f")
+                .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
+                .addBatch(createJsonDoc("d"), createJsonDoc("e"))
+                .addBatch(createJsonDoc("f"))
                 .finishMock();
 
         assertTrue(testIterator.hasNext());
 
         Deque<String> batch = testIterator.next();
         assertEquals(3, batch.size());
-        assertTrue(batch.containsAll(Arrays.asList("a", "b", "c")));
+        assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
 
         batch = testIterator.next();
         assertEquals(2, batch.size());
-        assertTrue(batch.containsAll(Arrays.asList("d", "e")));
+        assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("d"), createJsonDoc("e"))));
 
         batch = testIterator.next();
         assertEquals(1, batch.size());
-        assertTrue(batch.containsAll(Arrays.asList("f")));
+        assertTrue(batch.containsAll(Collections.singletonList(createJsonDoc("f"))));
 
         assertFalse(testIterator.hasNext());
         assertTrue(wasScrollCleared);
+
+        assertSearchRequest();
+        assertSearchScrollRequests(2);
     }
 
+    private String createJsonDoc(String value) {
+        return "{\"foo\":\"" + value + "\"}";
+    }
+
     private void givenClearScrollRequest() {
         ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class);
         when(client.prepareClearScroll()).thenReturn(requestBuilder);
-        when(requestBuilder.setScrollIds(Arrays.asList(SCROLL_ID))).thenReturn(requestBuilder);
+        when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
         when(requestBuilder.get()).thenAnswer((invocation) -> {
             wasScrollCleared = true;
             return null;
         });
     }
 
+    private void assertSearchRequest() {
+        List<SearchRequest> searchRequests = searchRequestCaptor.getAllValues();
+        assertThat(searchRequests.size(), equalTo(1));
+        SearchRequest searchRequest = searchRequests.get(0);
+        assertThat(searchRequest.indices(), equalTo(new String[] {INDEX_NAME}));
+        assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
+        assertThat(searchRequest.types().length, equalTo(0));
+        assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
+    }
+
+    private void assertSearchScrollRequests(int expectedCount) {
+        List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor.getAllValues();
+        assertThat(searchScrollRequests.size(), equalTo(expectedCount));
+        for (SearchScrollRequest request : searchScrollRequests) {
+            assertThat(request.scrollId(), equalTo(SCROLL_ID));
+            assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
+        }
+    }
+
     private class ScrollResponsesMocker {
         private List<String[]> batches = new ArrayList<>();
         private long totalHits = 0;
-        private List<SearchScrollRequestBuilder> nextRequestBuilders = new ArrayList<>();
+        private List<SearchResponse> responses = new ArrayList<>();
 
         ScrollResponsesMocker addBatch(String... hits) {
             totalHits += hits.length;
@@ -121,6 +159,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
             return this;
         }
 
+        @SuppressWarnings("unchecked")
         void finishMock() {
             if (batches.isEmpty()) {
                 givenInitialResponse();
@@ -130,39 +169,38 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
             for (int i = 1; i < batches.size(); ++i) {
                 givenNextResponse(batches.get(i));
             }
-            if (nextRequestBuilders.size() > 0) {
-                SearchScrollRequestBuilder first = nextRequestBuilders.get(0);
-                if (nextRequestBuilders.size() > 1) {
-                    SearchScrollRequestBuilder[] rest = new SearchScrollRequestBuilder[batches.size() - 1];
-                    for (int i = 1; i < nextRequestBuilders.size(); ++i) {
-                        rest[i - 1] = nextRequestBuilders.get(i);
+            if (responses.size() > 0) {
+                ActionFuture<SearchResponse> first = wrapResponse(responses.get(0));
+                if (responses.size() > 1) {
+                    List<ActionFuture> rest = new ArrayList<>();
+                    for (int i = 1; i < responses.size(); ++i) {
+                        rest.add(wrapResponse(responses.get(i)));
                    }
-                    when(client.prepareSearchScroll(SCROLL_ID)).thenReturn(first, rest);
+
+                    when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(
+                            first, rest.toArray(new ActionFuture[rest.size() - 1]));
                 } else {
-                    when(client.prepareSearchScroll(SCROLL_ID)).thenReturn(first);
+                    when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(first);
                 }
             }
         }
 
         private void givenInitialResponse(String... hits) {
             SearchResponse searchResponse = createSearchResponseWithHits(hits);
-            SearchRequestBuilder requestBuilder = mock(SearchRequestBuilder.class);
-            when(client.prepareSearch(INDEX_NAME)).thenReturn(requestBuilder);
-            when(requestBuilder.setScroll("5m")).thenReturn(requestBuilder);
-            when(requestBuilder.setSize(10000)).thenReturn(requestBuilder);
-            when(requestBuilder.setTypes("String")).thenReturn(requestBuilder);
-            when(requestBuilder.setQuery(any(QueryBuilder.class))).thenReturn(requestBuilder);
-            when(requestBuilder.addSort(any(SortBuilder.class))).thenReturn(requestBuilder);
-            when(requestBuilder.get()).thenReturn(searchResponse);
+            ActionFuture<SearchResponse> future = wrapResponse(searchResponse);
+            when(future.actionGet()).thenReturn(searchResponse);
+            when(client.search(searchRequestCaptor.capture())).thenReturn(future);
+        }
+
+        @SuppressWarnings("unchecked")
+        private ActionFuture<SearchResponse> wrapResponse(SearchResponse searchResponse) {
+            ActionFuture<SearchResponse> future = mock(ActionFuture.class);
+            when(future.actionGet()).thenReturn(searchResponse);
+            return future;
        }
 
         private void givenNextResponse(String... hits) {
-            SearchResponse searchResponse = createSearchResponseWithHits(hits);
-            SearchScrollRequestBuilder requestBuilder = mock(SearchScrollRequestBuilder.class);
-            when(requestBuilder.setScrollId(SCROLL_ID)).thenReturn(requestBuilder);
-            when(requestBuilder.setScroll("5m")).thenReturn(requestBuilder);
-            when(requestBuilder.get()).thenReturn(searchResponse);
-            nextRequestBuilders.add(requestBuilder);
+            responses.add(createSearchResponseWithHits(hits));
         }
 
         private SearchResponse createSearchResponseWithHits(String... hits) {
@@ -174,16 +212,11 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         }
 
         private SearchHits createHits(String... values) {
-            SearchHits searchHits = mock(SearchHits.class);
             List<SearchHit> hits = new ArrayList<>();
             for (String value : values) {
-                SearchHit hit = mock(SearchHit.class);
-                when(hit.getSourceAsString()).thenReturn(value);
-                hits.add(hit);
+                hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
             }
-            when(searchHits.getTotalHits()).thenReturn(totalHits);
-            when(searchHits.getHits()).thenReturn(hits.toArray(new SearchHit[hits.size()]));
-            return searchHits;
+            return new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, 1.0f);
         }
     }
 
@@ -193,8 +226,8 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         }
 
         @Override
-        protected String getType() {
-            return "String";
+        protected QueryBuilder getQuery() {
+            return QueryBuilders.matchAllQuery();
         }
 
         @Override
@@ -202,5 +235,4 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
             return hit.getSourceAsString();
         }
     }
-
 }
MockBatchedDocumentsIterator.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.job.persistence;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.ml.job.results.Result;
 
 import java.util.Deque;
 import java.util.List;
@@ -15,34 +16,35 @@ import java.util.NoSuchElementException;
 
 import static org.mockito.Mockito.mock;
 
-public class MockBatchedDocumentsIterator<T> extends BatchedDocumentsIterator<T> {
-    private final List<Deque<T>> batches;
+public class MockBatchedDocumentsIterator<T> extends BatchedResultsIterator<T> {
+
+    private final List<Deque<Result<T>>> batches;
     private int index;
     private boolean wasTimeRangeCalled;
     private Boolean includeInterim;
     private Boolean requireIncludeInterim;
 
-    public MockBatchedDocumentsIterator(List<Deque<T>> batches) {
-        super(mock(Client.class), "foo");
+    public MockBatchedDocumentsIterator(List<Deque<Result<T>>> batches, String resultType) {
+        super(mock(Client.class), "foo", resultType);
         this.batches = batches;
         index = 0;
         wasTimeRangeCalled = false;
     }
 
     @Override
-    public BatchedDocumentsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
+    public BatchedResultsIterator<T> timeRange(long startEpochMs, long endEpochMs) {
         wasTimeRangeCalled = true;
         return this;
     }
 
     @Override
-    public BatchedDocumentsIterator<T> includeInterim(boolean includeInterim) {
+    public BatchedResultsIterator<T> includeInterim(boolean includeInterim) {
         this.includeInterim = includeInterim;
         return this;
     }
 
     @Override
-    public Deque<T> next() {
+    public Deque<Result<T>> next() {
         if (requireIncludeInterim != null && requireIncludeInterim != includeInterim) {
             throw new IllegalStateException("Required include interim value [" + requireIncludeInterim + "]; actual was ["
                     + includeInterim + "]");
@@ -54,12 +56,7 @@ public class MockBatchedDocumentsIterator<T> extends BatchedDocumentsIterator<T>
     }
 
     @Override
-    protected String getType() {
-        return null;
-    }
-
-    @Override
-    protected T map(SearchHit hit) {
+    protected Result<T> map(SearchHit hit) {
         return null;
     }
 
ScoresUpdaterTests.java

@@ -209,8 +209,8 @@ public class ScoresUpdaterTests extends ESTestCase {
 
         List<Deque<Result<AnomalyRecord>>> recordBatches = new ArrayList<>();
         recordBatches.add(new ArrayDeque<>(records));
-        MockBatchedDocumentsIterator<Result<AnomalyRecord>> recordIter =
-                new MockBatchedDocumentsIterator<>(recordBatches);
+        MockBatchedDocumentsIterator<AnomalyRecord> recordIter = new MockBatchedDocumentsIterator<>(
+                recordBatches, AnomalyRecord.RESULT_TYPE_VALUE);
         recordIter.requireIncludeInterim(false);
         when(jobProvider.newBatchedRecordsIterator(JOB_ID)).thenReturn(recordIter);
 
@@ -376,8 +376,7 @@ public class ScoresUpdaterTests extends ESTestCase {
             batchesWithIndex.add(queueWithIndex);
         }
 
-        MockBatchedDocumentsIterator<Result<Bucket>> bucketIter =
-                new MockBatchedDocumentsIterator<>(batchesWithIndex);
+        MockBatchedDocumentsIterator<Bucket> bucketIter = new MockBatchedDocumentsIterator<>(batchesWithIndex, Bucket.RESULT_TYPE_VALUE);
         bucketIter.requireIncludeInterim(false);
         when(jobProvider.newBatchedBucketsIterator(JOB_ID)).thenReturn(bucketIter);
     }
@@ -394,8 +393,8 @@ public class ScoresUpdaterTests extends ESTestCase {
         }
         batches.add(batch);
 
-        MockBatchedDocumentsIterator<Result<AnomalyRecord>> recordIter =
-                new MockBatchedDocumentsIterator<>(batches);
+        MockBatchedDocumentsIterator<AnomalyRecord> recordIter = new MockBatchedDocumentsIterator<>(
+                batches, AnomalyRecord.RESULT_TYPE_VALUE);
         recordIter.requireIncludeInterim(false);
         when(jobProvider.newBatchedRecordsIterator(JOB_ID)).thenReturn(recordIter);
     }
@@ -411,7 +410,7 @@ public class ScoresUpdaterTests extends ESTestCase {
             queue.add(new Result<>("foo", inf));
         }
         batches.add(queue);
-        MockBatchedDocumentsIterator<Result<Influencer>> iterator = new MockBatchedDocumentsIterator<>(batches);
+        MockBatchedDocumentsIterator<Influencer> iterator = new MockBatchedDocumentsIterator<>(batches, Influencer.RESULT_TYPE_VALUE);
         iterator.requireIncludeInterim(false);
         when(jobProvider.newBatchedInfluencersIterator(JOB_ID)).thenReturn(iterator);
     }
SearchHitBuilder.java (new file)

@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.test;
+
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHitField;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Utility class to build {@link SearchHit} in tests
+ */
+public class SearchHitBuilder {
+
+    private final SearchHit hit;
+    private final Map<String, SearchHitField> fields;
+
+    public SearchHitBuilder(int docId) {
+        hit = new SearchHit(docId);
+        fields = new HashMap<>();
+    }
+
+    public SearchHitBuilder addField(String name, Object value) {
+        return addField(name, Arrays.asList(value));
+    }
+
+    public SearchHitBuilder addField(String name, List<Object> values) {
+        fields.put(name, new SearchHitField(name, values));
+        return this;
+    }
+
+    public SearchHitBuilder setSource(String sourceJson) {
+        hit.sourceRef(new BytesArray(sourceJson));
+        return this;
+    }
+
+    public SearchHit build() {
+        if (!fields.isEmpty()) {
+            hit.fields(fields);
+        }
+        return hit;
+    }
+}
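This helper centralizes the builder that ExtractedFieldTests previously declared as a private inner class, so the other test files can share it. Typical use, assembled from calls that appear in the updated tests (the field names come from the diff; the source string here is an illustrative value):

// Build a SearchHit carrying doc-value fields and a _source, no mocks needed.
SearchHit hit = new SearchHitBuilder(8)
        .addField("time", 1000L)
        .addField("array", Arrays.asList("b", "c"))
        .setSource("{\"foo\":\"a\"}")
        .build();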