[ML] Refactor doc value format into ExtractedField (#35053)

This commit moves the knowledge of which doc value format
to use down into `ExtractedField`, rather than keeping it
in the data extractor.
This commit is contained in:
Dimitris Athanasiou 2018-10-29 22:56:53 +00:00 committed by GitHub
parent 794d4fa879
commit d85a654ebb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 27 deletions

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
import org.joda.time.base.BaseDateTime;
import java.util.List;
@ -51,6 +52,10 @@ abstract class ExtractedField {
public abstract Object[] value(SearchHit hit);
/**
 * Returns the doc value format associated with this field.
 * This base implementation returns {@link DocValueFieldsContext#USE_DEFAULT_FORMAT};
 * the nested {@code TimeField} class overrides it.
 *
 * @return the format string to pass along when the field is added as a doc value field
 */
public String getDocValueFormat() {
return DocValueFieldsContext.USE_DEFAULT_FORMAT;
}
public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) {
if (extractionMethod == ExtractionMethod.SOURCE) {
throw new IllegalArgumentException("time field cannot be extracted from source");
@ -93,6 +98,8 @@ abstract class ExtractedField {
private static class TimeField extends FromFields {
private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
TimeField(String name, ExtractionMethod extractionMethod) {
super(name, name, extractionMethod);
}
@ -112,6 +119,11 @@ abstract class ExtractedField {
}
return value;
}
/**
 * Returns the format for the time field, which is the {@code EPOCH_MILLIS_FORMAT}
 * constant ({@code "epoch_millis"}) instead of the default format.
 */
@Override
public String getDocValueFormat() {
return EPOCH_MILLIS_FORMAT;
}
}
private static class FromSource extends ExtractedField {

View File

@ -31,7 +31,7 @@ class ExtractedFields {
private final ExtractedField timeField;
private final List<ExtractedField> allFields;
private final String[] docValueFields;
private final List<ExtractedField> docValueFields;
private final String[] sourceFields;
ExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
@ -41,7 +41,8 @@ class ExtractedFields {
this.timeField = Objects.requireNonNull(timeField);
this.allFields = Collections.unmodifiableList(allFields);
this.docValueFields = filterFields(ExtractedField.ExtractionMethod.DOC_VALUE, allFields);
this.sourceFields = filterFields(ExtractedField.ExtractionMethod.SOURCE, allFields);
this.sourceFields = filterFields(ExtractedField.ExtractionMethod.SOURCE, allFields).stream().map(ExtractedField::getName)
.toArray(String[]::new);
}
public List<ExtractedField> getAllFields() {
@ -52,18 +53,12 @@ class ExtractedFields {
return sourceFields;
}
public String[] getDocValueFields() {
public List<ExtractedField> getDocValueFields() {
return docValueFields;
}
private static String[] filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
List<String> result = new ArrayList<>();
for (ExtractedField field : fields) {
if (field.getExtractionMethod() == method) {
result.add(field.getName());
}
}
return result.toArray(new String[result.size()]);
private static List<ExtractedField> filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList());
}
public String timeField() {

View File

@ -19,7 +19,6 @@ import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.datafeed.extractor.DataExtractor;
@ -44,7 +43,6 @@ class ScrollDataExtractor implements DataExtractor {
private static final Logger LOGGER = LogManager.getLogger(ScrollDataExtractor.class);
private static final TimeValue SCROLL_TIMEOUT = new TimeValue(30, TimeUnit.MINUTES);
private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
private final Client client;
private final ScrollDataExtractorContext context;
@ -112,12 +110,8 @@ class ScrollDataExtractor implements DataExtractor {
.setQuery(ExtractorUtils.wrapInTimeRangeQuery(
context.query, context.extractedFields.timeField(), start, context.end));
for (String docValueField : context.extractedFields.getDocValueFields()) {
if (docValueField.equals(context.extractedFields.timeField())) {
searchRequestBuilder.addDocValueField(docValueField, EPOCH_MILLIS_FORMAT);
} else {
searchRequestBuilder.addDocValueField(docValueField, DocValueFieldsContext.USE_DEFAULT_FORMAT);
}
for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
searchRequestBuilder.addDocValueField(docValueField.getName(), docValueField.getDocValueFormat());
}
String[] sourceFields = context.extractedFields.getSourceFields();
if (sourceFields.length == 0) {

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.joda.time.DateTime;
@ -140,4 +141,14 @@ public class ExtractedFieldTests extends ESTestCase {
assertThat(field.getName(), equalTo("b"));
assertThat(field.value(hit), equalTo(new Integer[] { 2 }));
}
/**
 * Checks that fields created via {@code newField} report the default doc value format
 * for every extraction method, and that fields created via {@code newTimeField} report
 * {@code "epoch_millis"}.
 */
public void testGetDocValueFormat() {
// Non-time fields: every extraction method is expected to yield the default format.
for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) {
assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(DocValueFieldsContext.USE_DEFAULT_FORMAT));
}
// Time fields: "epoch_millis" is expected for DOC_VALUE extraction...
assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(),
equalTo("epoch_millis"));
// ...and for SCRIPT_FIELD extraction.
// NOTE(review): the field is named "source_time" but is created with SCRIPT_FIELD;
// newTimeField rejects ExtractionMethod.SOURCE, so the name looks misleading - consider renaming.
assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(),
equalTo("epoch_millis"));
}
}

View File

@ -11,6 +11,7 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
@ -43,7 +44,8 @@ public class ExtractedFieldsTests extends ESTestCase {
assertThat(extractedFields.getAllFields(), equalTo(Arrays.asList(timeField)));
assertThat(extractedFields.timeField(), equalTo("time"));
assertThat(extractedFields.getDocValueFields(), equalTo(new String[] { timeField.getName() }));
assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
equalTo(new String[] { timeField.getName() }));
assertThat(extractedFields.getSourceFields().length, equalTo(0));
}
@ -59,7 +61,8 @@ public class ExtractedFieldsTests extends ESTestCase {
assertThat(extractedFields.getAllFields().size(), equalTo(7));
assertThat(extractedFields.timeField(), equalTo("time"));
assertThat(extractedFields.getDocValueFields(), equalTo(new String[] {"time", "doc1", "doc2"}));
assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
equalTo(new String[] {"time", "doc1", "doc2"}));
assertThat(extractedFields.getSourceFields(), equalTo(new String[] {"src1", "src2"}));
}
@ -138,9 +141,11 @@ public class ExtractedFieldsTests extends ESTestCase {
fieldCapabilitiesResponse);
assertThat(extractedFields.timeField(), equalTo("time"));
assertThat(extractedFields.getDocValueFields().length, equalTo(2));
assertThat(extractedFields.getDocValueFields()[0], equalTo("time"));
assertThat(extractedFields.getDocValueFields()[1], equalTo("value"));
assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
assertThat(extractedFields.getDocValueFields().get(0).getDocValueFormat(), equalTo("epoch_millis"));
assertThat(extractedFields.getDocValueFields().get(1).getName(), equalTo("value"));
assertThat(extractedFields.getDocValueFields().get(1).getDocValueFormat(), equalTo(DocValueFieldsContext.USE_DEFAULT_FORMAT));
assertThat(extractedFields.getSourceFields().length, equalTo(1));
assertThat(extractedFields.getSourceFields()[0], equalTo("airline"));
assertThat(extractedFields.getAllFields().size(), equalTo(4));
@ -174,9 +179,9 @@ public class ExtractedFieldsTests extends ESTestCase {
fieldCapabilitiesResponse);
assertThat(extractedFields.timeField(), equalTo("time"));
assertThat(extractedFields.getDocValueFields().length, equalTo(2));
assertThat(extractedFields.getDocValueFields()[0], equalTo("time"));
assertThat(extractedFields.getDocValueFields()[1], equalTo("airport.keyword"));
assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
assertThat(extractedFields.getDocValueFields().get(1).getName(), equalTo("airport.keyword"));
assertThat(extractedFields.getSourceFields().length, equalTo(1));
assertThat(extractedFields.getSourceFields()[0], equalTo("airline"));
assertThat(extractedFields.getAllFields().size(), equalTo(3));