Mirror of https://github.com/honeymoose/OpenSearch.git, synced 2025-03-24 17:09:48 +00:00
[ML] adds new n_gram_encoding custom processor (#61578)

This adds a new `n_gram_encoding` feature processor for analytics and inference. The focus of this processor is simple n-gram encodings that allow:
- multiple n-grams [1..5]
- prefix, infix, suffix
This commit is contained in:
parent 7b021bf3fb
commit cec102a391
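As a quick orientation before the diffs: judging from the parse fields this commit introduces, a definition of the new processor would look roughly like the sketch below. This is an illustrative reconstruction, not an example from the commit; the field name "text-field" and prefix "f" are made up, and "feature_prefix", "start", "length", and "custom" are optional (the server side falls back to a generated prefix and to start 0 / length 50).

{
  "n_gram_encoding": {
    "field": "text-field",
    "feature_prefix": "f",
    "n_grams": [1, 2],
    "start": 0,
    "length": 5
  }
}

With n_grams [1, 2] over a five-character window this emits each single character and each adjacent character pair as its own feature, named f.<gram><position>.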
@@ -19,6 +19,7 @@
 package org.elasticsearch.client.ml.inference;

 import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
+import org.elasticsearch.client.ml.inference.preprocessing.NGram;
 import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
@@ -57,6 +58,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             FrequencyEncoding::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME),
             CustomWordEmbedding::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME),
+            NGram::fromXContent));

         // Model
         namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
@@ -0,0 +1,211 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.preprocessing;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * PreProcessor for n-gram encoding a string
 */
public class NGram implements PreProcessor {

    public static final String NAME = "n_gram_encoding";
    public static final ParseField FIELD = new ParseField("field");
    public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix");
    public static final ParseField NGRAMS = new ParseField("n_grams");
    public static final ParseField START = new ParseField("start");
    public static final ParseField LENGTH = new ParseField("length");
    public static final ParseField CUSTOM = new ParseField("custom");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<NGram, Void> PARSER = new ConstructingObjectParser<NGram, Void>(
        NAME,
        true,
        a -> new NGram((String) a[0],
            (List<Integer>) a[1],
            (Integer) a[2],
            (Integer) a[3],
            (Boolean) a[4],
            (String) a[5]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        PARSER.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), START);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX);
    }

    public static NGram fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final String field;
    private final String featurePrefix;
    private final List<Integer> nGrams;
    private final Integer start;
    private final Integer length;
    private final Boolean custom;

    NGram(String field, List<Integer> nGrams, Integer start, Integer length, Boolean custom, String featurePrefix) {
        this.field = field;
        this.featurePrefix = featurePrefix;
        this.nGrams = nGrams;
        this.start = start;
        this.length = length;
        this.custom = custom;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (field != null) {
            builder.field(FIELD.getPreferredName(), field);
        }
        if (featurePrefix != null) {
            builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix);
        }
        if (nGrams != null) {
            builder.field(NGRAMS.getPreferredName(), nGrams);
        }
        if (start != null) {
            builder.field(START.getPreferredName(), start);
        }
        if (length != null) {
            builder.field(LENGTH.getPreferredName(), length);
        }
        if (custom != null) {
            builder.field(CUSTOM.getPreferredName(), custom);
        }
        builder.endObject();
        return builder;
    }

    public String getField() {
        return field;
    }

    public String getFeaturePrefix() {
        return featurePrefix;
    }

    public List<Integer> getnGrams() {
        return nGrams;
    }

    public Integer getStart() {
        return start;
    }

    public Integer getLength() {
        return length;
    }

    public Boolean getCustom() {
        return custom;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        NGram nGram = (NGram) o;
        return Objects.equals(field, nGram.field) &&
            Objects.equals(featurePrefix, nGram.featurePrefix) &&
            Objects.equals(nGrams, nGram.nGrams) &&
            Objects.equals(start, nGram.start) &&
            Objects.equals(length, nGram.length) &&
            Objects.equals(custom, nGram.custom);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, featurePrefix, start, length, custom, nGrams);
    }

    public static Builder builder(String field) {
        return new Builder(field);
    }

    public static class Builder {

        private String field;
        private String featurePrefix;
        private List<Integer> nGrams;
        private Integer start;
        private Integer length;
        private Boolean custom;

        public Builder(String field) {
            this.field = field;
        }

        public Builder setField(String field) {
            this.field = field;
            return this;
        }

        public Builder setCustom(boolean custom) {
            this.custom = custom;
            return this;
        }

        public Builder setFeaturePrefix(String featurePrefix) {
            this.featurePrefix = featurePrefix;
            return this;
        }

        public Builder setnGrams(List<Integer> nGrams) {
            this.nGrams = nGrams;
            return this;
        }

        public Builder setStart(Integer start) {
            this.start = start;
            return this;
        }

        public Builder setLength(Integer length) {
            this.length = length;
            return this;
        }

        public Builder setCustom(Boolean custom) {
            this.custom = custom;
            return this;
        }

        public NGram build() {
            return new NGram(field, nGrams, start, length, custom, featurePrefix);
        }
    }
}
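A minimal usage sketch of the class above (illustrative, not part of the diff): it assumes the standard Arrays utility and org.elasticsearch.common.Strings, and that the client PreProcessor interface extends ToXContent as the other client preprocessors do.

// Hypothetical snippet using the HLRC NGram added above; values are made up.
NGram nGram = NGram.builder("text_field")      // required source field
    .setnGrams(Arrays.asList(1, 2))            // unigrams and bigrams
    .setStart(0)                               // offset into the string value
    .setLength(5)                              // number of characters to encode
    .setFeaturePrefix("f")                     // features are named f.<gram><position>
    .build();
// org.elasticsearch.common.Strings renders a ToXContent object as JSON:
String json = Strings.toString(nGram);
// => {"field":"text_field","feature_prefix":"f","n_grams":[1,2],"start":0,"length":5}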
@@ -74,6 +74,7 @@ import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
 import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
 import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.client.ml.inference.preprocessing.NGram;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
 import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
@@ -704,7 +705,7 @@ public class RestHighLevelClientTests extends ESTestCase {

     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(69, namedXContents.size());
+        assertEquals(70, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -785,8 +786,9 @@ public class RestHighLevelClientTests extends ESTestCase {
             registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
             registeredMetricName(Regression.NAME, HuberMetric.NAME),
             registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
-        assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
-        assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
+        assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
+        assertThat(names,
+            hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME, NGram.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
         assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
         assertEquals(Integer.valueOf(4),
@@ -0,0 +1,55 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.preprocessing;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class NGramTests extends AbstractXContentTestCase<NGram> {

    @Override
    protected NGram doParseInstance(XContentParser parser) throws IOException {
        return NGram.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected NGram createTestInstance() {
        return createRandom();
    }

    public static NGram createRandom() {
        return new NGram(randomAlphaOfLength(10),
            IntStream.range(1, 5).limit(5).boxed().collect(Collectors.toList()),
            randomBoolean() ? null : randomIntBetween(0, 10),
            randomBoolean() ? null : randomIntBetween(1, 10),
            randomBoolean() ? null : randomBoolean(),
            randomBoolean() ? null : randomAlphaOfLength(10));
    }

}
@@ -161,6 +161,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
@@ -525,6 +526,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPlugin {
             new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
             new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new),
             new NamedWriteableRegistry.Entry(PreProcessor.class, CustomWordEmbedding.NAME.getPreferredName(), CustomWordEmbedding::new),
+            new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(), NGram::new),
             // ML - Inference models
             new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new),
             new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new),
@@ -9,6 +9,13 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
@@ -39,12 +46,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;

 import java.util.ArrayList;
 import java.util.List;
@@ -64,6 +65,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             (p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
             (p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, NGram.NAME,
+            (p, c) -> NGram.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));

         // PreProcessing Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
@@ -74,6 +77,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             (p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
             (p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, NGram.NAME,
+            (p, c) -> NGram.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));

         // Model Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
@@ -154,6 +159,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             FrequencyEncoding::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, CustomWordEmbedding.NAME.getPreferredName(),
             CustomWordEmbedding::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(),
+            NGram::new));

         // Model
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));
@@ -0,0 +1,304 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.preprocessing;

import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.apache.lucene.util.RamUsageEstimator.sizeOf;

/**
 * PreProcessor for n-gram encoding a string
 */
public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

    private static final int DEFAULT_START = 0;
    private static final int DEFAULT_LENGTH = 50;
    private static final int MAX_LENGTH = 100;
    private static final int MIN_GRAM = 1;
    private static final int MAX_GRAM = 5;

    private static String defaultPrefix(Integer start, Integer length) {
        return "ngram_"
            + (start == null ? DEFAULT_START : start)
            + "_"
            + (length == null ? DEFAULT_LENGTH : length);
    }

    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NGram.class);
    public static final ParseField NAME = new ParseField("n_gram_encoding");
    public static final ParseField FIELD = new ParseField("field");
    public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix");
    public static final ParseField NGRAMS = new ParseField("n_grams");
    public static final ParseField START = new ParseField("start");
    public static final ParseField LENGTH = new ParseField("length");
    public static final ParseField CUSTOM = new ParseField("custom");

    private static final ConstructingObjectParser<NGram, PreProcessorParseContext> STRICT_PARSER = createParser(false);
    private static final ConstructingObjectParser<NGram, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<NGram, PreProcessorParseContext> createParser(boolean lenient) {
        ConstructingObjectParser<NGram, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            (a, c) -> new NGram((String) a[0],
                (List<Integer>) a[1],
                (Integer) a[2],
                (Integer) a[3],
                a[4] == null ? c.isCustomByDefault() : (Boolean) a[4],
                (String) a[5]));
        parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        parser.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), START);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH);
        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX);
        return parser;
    }

    public static NGram fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
    }

    public static NGram fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
    }

    private final String field;
    private final String featurePrefix;
    private final int[] nGrams;
    private final int start;
    private final int length;
    private final boolean custom;

    NGram(String field,
          List<Integer> nGrams,
          Integer start,
          Integer length,
          Boolean custom,
          String featurePrefix) {
        this(field,
            featurePrefix == null ? defaultPrefix(start, length) : featurePrefix,
            Sets.newHashSet(nGrams).stream().mapToInt(Integer::intValue).toArray(),
            start == null ? DEFAULT_START : start,
            length == null ? DEFAULT_LENGTH : length,
            custom != null && custom);
    }

    public NGram(String field, String featurePrefix, int[] nGrams, int start, int length, boolean custom) {
        this.field = ExceptionsHelper.requireNonNull(field, FIELD);
        this.featurePrefix = ExceptionsHelper.requireNonNull(featurePrefix, FEATURE_PREFIX);
        this.nGrams = ExceptionsHelper.requireNonNull(nGrams, NGRAMS);
        if (Arrays.stream(this.nGrams).anyMatch(i -> i < MIN_GRAM || i > MAX_GRAM)) {
            throw ExceptionsHelper.badRequestException(
                "[{}] is invalid [{}]; minimum supported value is [{}]; maximum supported value is [{}]",
                NGRAMS.getPreferredName(),
                Arrays.stream(nGrams).mapToObj(String::valueOf).collect(Collectors.joining(", ")),
                MIN_GRAM,
                MAX_GRAM);
        }
        this.start = start;
        if (start < 0 && length + start > 0) {
            throw ExceptionsHelper.badRequestException(
                "if [start] is negative, [length] + [start] must be less than 0");
        }
        this.length = length;
        if (length <= 0) {
            throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", LENGTH.getPreferredName());
        }
        if (length > MAX_LENGTH) {
            throw ExceptionsHelper.badRequestException("[{}] must not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
        }
        this.custom = custom;
    }

    public NGram(StreamInput in) throws IOException {
        this.field = in.readString();
        this.featurePrefix = in.readString();
        this.nGrams = in.readVIntArray();
        this.start = in.readInt();
        this.length = in.readVInt();
        this.custom = in.readBoolean();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(field);
        out.writeString(featurePrefix);
        out.writeVIntArray(nGrams);
        out.writeInt(start);
        out.writeVInt(length);
        out.writeBoolean(custom);
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public List<String> inputFields() {
        return Collections.singletonList(field);
    }

    @Override
    public List<String> outputFields() {
        return allPossibleNGramOutputFeatureNames();
    }

    @Override
    public void process(Map<String, Object> fields) {
        Object value = fields.get(field);
        if (value == null) {
            return;
        }
        final String stringValue = value.toString();
        // String is too small for the starting point
        if (start > stringValue.length() || stringValue.length() + start < 0) {
            return;
        }
        final int startPos = start < 0 ? (stringValue.length() + start) : start;
        final int len = Math.min(startPos + length, stringValue.length());
        for (int i = 0; i < len; i++) {
            for (int nGram : nGrams) {
                if (startPos + i + nGram > len) {
                    break;
                }
                fields.put(nGramFeature(nGram, i), stringValue.substring(startPos + i, startPos + i + nGram));
            }
        }
    }

    @Override
    public Map<String, String> reverseLookup() {
        return outputFields().stream().collect(Collectors.toMap(Function.identity(), ignored -> field));
    }

    @Override
    public String getOutputFieldType(String outputField) {
        return TextFieldMapper.CONTENT_TYPE;
    }

    @Override
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += sizeOf(field);
        size += sizeOf(featurePrefix);
        size += sizeOf(nGrams);
        return size;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), field);
        builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix);
        builder.field(NGRAMS.getPreferredName(), nGrams);
        builder.field(START.getPreferredName(), start);
        builder.field(LENGTH.getPreferredName(), length);
        builder.field(CUSTOM.getPreferredName(), custom);
        builder.endObject();
        return builder;
    }

    public String getField() {
        return field;
    }

    public String getFeaturePrefix() {
        return featurePrefix;
    }

    public int[] getnGrams() {
        return nGrams;
    }

    public int getStart() {
        return start;
    }

    public int getLength() {
        return length;
    }

    @Override
    public boolean isCustom() {
        return custom;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        NGram nGram = (NGram) o;
        return start == nGram.start &&
            length == nGram.length &&
            custom == nGram.custom &&
            Objects.equals(field, nGram.field) &&
            Objects.equals(featurePrefix, nGram.featurePrefix) &&
            Arrays.equals(nGrams, nGram.nGrams);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(field, featurePrefix, start, length, custom);
        result = 31 * result + Arrays.hashCode(nGrams);
        return result;
    }

    private String nGramFeature(int nGram, int pos) {
        return featurePrefix
            + "."
            + nGram
            + pos;
    }

    private List<String> allPossibleNGramOutputFeatureNames() {
        int totalNgrams = 0;
        for (int nGram : nGrams) {
            totalNgrams += (length - (nGram - 1));
        }
        List<String> ngramOutputs = new ArrayList<>(totalNgrams);

        for (int nGram : nGrams) {
            IntFunction<String> func = i -> nGramFeature(nGram, i);
            IntStream.range(0, (length - (nGram - 1))).mapToObj(func).forEach(ngramOutputs::add);
        }
        return ngramOutputs;
    }
}
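To make the naming scheme in nGramFeature concrete, a small worked sketch (illustrative, not part of the diff); each emitted feature is keyed featurePrefix + "." + nGram + position:

// Hypothetical fragment walking through process() on the server-side NGram above.
// Assumes java.util.HashMap and java.util.Map are imported.
Map<String, Object> doc = new HashMap<>();
doc.put("text", "hello world");
// field "text", prefix "f", 1-grams and 3-grams, start 0, window length 5:
new NGram("text", "f", new int[]{1, 3}, 0, 5, false).process(doc);
// Only the first five characters ("hello") fall inside the window, so the
// document map now additionally contains:
//   f.10 = "h", f.11 = "e", f.12 = "l", f.13 = "l", f.14 = "o"   (1-grams)
//   f.30 = "hel", f.31 = "ell", f.32 = "llo"                     (3-grams)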
@@ -36,7 +36,7 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
     List<String> inputFields();

     /**
-     * @return The resulting output fields
+     * @return The resulting output fields. It is imperative that the order is consistent between calls.
      */
     List<String> outputFields();

@@ -0,0 +1,144 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.preprocessing;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.hamcrest.Matcher;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class NGramTests extends PreProcessingTests<NGram> {

    @Override
    protected NGram doParseInstance(XContentParser parser) throws IOException {
        return lenient ?
            NGram.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
            NGram.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
    }

    @Override
    protected NGram createTestInstance() {
        return createRandom();
    }

    public static NGram createRandom() {
        return createRandom(randomBoolean() ? randomBoolean() : null);
    }

    public static NGram createRandom(Boolean isCustom) {
        return new NGram(
            randomAlphaOfLength(10),
            IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()),
            randomBoolean() ? null : randomIntBetween(0, 10),
            randomBoolean() ? null : randomIntBetween(1, 10),
            isCustom,
            randomBoolean() ? null : randomAlphaOfLength(10));
    }

    @Override
    protected Writeable.Reader<NGram> instanceReader() {
        return NGram::new;
    }

    public void testProcessNGramPrefix() {
        String field = "text";
        String fieldValue = "this is the value";
        NGram encoding = new NGram(field, "f", new int[]{1, 4}, 0, 5, false);
        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);

        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
        matchers.put("f.10", equalTo("t"));
        matchers.put("f.11", equalTo("h"));
        matchers.put("f.12", equalTo("i"));
        matchers.put("f.13", equalTo("s"));
        matchers.put("f.14", equalTo(" "));
        matchers.put("f.40", equalTo("this"));
        matchers.put("f.41", equalTo("his "));
        testProcess(encoding, fieldValues, matchers);
    }

    public void testProcessNGramSuffix() {
        String field = "text";
        String fieldValue = "this is the value";

        NGram encoding = new NGram(field, "f", new int[]{1, 3}, -3, 3, false);
        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
        matchers.put("f.10", equalTo("l"));
        matchers.put("f.11", equalTo("u"));
        matchers.put("f.12", equalTo("e"));
        matchers.put("f.30", equalTo("lue"));
        matchers.put("f.31", is(nullValue()));
        testProcess(encoding, fieldValues, matchers);
    }

    public void testProcessNGramInfix() {
        String field = "text";
        String fieldValue = "this is the value";

        NGram encoding = new NGram(field, "f", new int[]{1, 3}, 3, 3, false);
        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
        matchers.put("f.10", equalTo("s"));
        matchers.put("f.11", equalTo(" "));
        matchers.put("f.12", equalTo("i"));
        matchers.put("f.30", equalTo("s i"));
        matchers.put("f.31", is(nullValue()));
        testProcess(encoding, fieldValues, matchers);
    }

    public void testProcessNGramLengthOverrun() {
        String field = "text";
        String fieldValue = "this is the value";

        NGram encoding = new NGram(field, "f", new int[]{1, 3}, 12, 10, false);
        Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
        Map<String, Matcher<? super Object>> matchers = new HashMap<>();
        matchers.put("f.10", equalTo("v"));
        matchers.put("f.11", equalTo("a"));
        matchers.put("f.12", equalTo("l"));
        matchers.put("f.13", equalTo("u"));
        matchers.put("f.14", equalTo("e"));
        matchers.put("f.30", equalTo("val"));
        matchers.put("f.31", equalTo("alu"));
        matchers.put("f.32", equalTo("lue"));
        testProcess(encoding, fieldValues, matchers);
    }

    public void testInputOutputFields() {
        String field = randomAlphaOfLength(10);
        NGram encoding = new NGram(field, "f", new int[]{1, 4}, 0, 5, false);
        assertThat(encoding.inputFields(), containsInAnyOrder(field));
        assertThat(encoding.outputFields(),
            contains("f.10", "f.11", "f.12", "f.13", "f.14", "f.40", "f.41"));

        encoding = new NGram(field, Arrays.asList(1, 4), 0, 5, false, null);
        assertThat(encoding.inputFields(), containsInAnyOrder(field));
        assertThat(encoding.outputFields(),
            contains(
                "ngram_0_5.10",
                "ngram_0_5.11",
                "ngram_0_5.12",
                "ngram_0_5.13",
                "ngram_0_5.14",
                "ngram_0_5.40",
                "ngram_0_5.41"));
    }

}
@@ -0,0 +1,277 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;

public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalyticsIntegTestCase {

    private static final String BOOLEAN_FIELD = "boolean-field";
    private static final String NUMERICAL_FIELD = "numerical-field";
    private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
    private static final String TEXT_FIELD = "text-field";
    private static final String KEYWORD_FIELD = "keyword-field";
    private static final String NESTED_FIELD = "outer-field.inner-field";
    private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
    private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
    private static final List<Boolean> BOOLEAN_FIELD_VALUES = org.elasticsearch.common.collect.List.of(false, true);
    private static final List<Double> NUMERICAL_FIELD_VALUES = org.elasticsearch.common.collect.List.of(1.0, 2.0);
    private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = org.elasticsearch.common.collect.List.of(10, 20);
    private static final List<String> KEYWORD_FIELD_VALUES = org.elasticsearch.common.collect.List.of("cat", "dog");

    private String jobId;
    private String sourceIndex;
    private String destIndex;

    @Before
    public void setupLogging() {
        client().admin().cluster()
            .prepareUpdateSettings()
            .setTransientSettings(Settings.builder()
                .put("logger.org.elasticsearch.xpack.ml.dataframe", "DEBUG")
                .put("logger.org.elasticsearch.xpack.core.ml.inference", "DEBUG"))
            .get();
    }

    @After
    public void cleanup() {
        cleanUp();
        client().admin().cluster()
            .prepareUpdateSettings()
            .setTransientSettings(Settings.builder()
                .putNull("logger.org.elasticsearch.xpack.ml.dataframe")
                .putNull("logger.org.elasticsearch.xpack.core.ml.inference"))
            .get();
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(entries);
    }

    public void testNGramCustomFeature() throws Exception {
        initialize("test_ngram_feature_processor");
        String predictedClassField = NUMERICAL_FIELD + "_prediction";
        indexData(sourceIndex, 300, 50, NUMERICAL_FIELD);

        DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
            .setId(jobId)
            .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
                QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()), null))
            .setDest(new DataFrameAnalyticsDest(destIndex, null))
            .setAnalysis(new Regression(NUMERICAL_FIELD,
                BoostedTreeParams.builder().setNumTopFeatureImportanceValues(6).build(),
                null,
                null,
                42L,
                null,
                null,
                Collections.singletonList(new NGram(TEXT_FIELD, "f", new int[]{1, 2}, 0, 2, true))))
            .setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
            .build();
        putAnalytics(config);

        assertIsStopped(jobId);
        assertProgressIsZero(jobId);

        startAnalytics(jobId);
        waitUntilAnalyticsIsStopped(jobId);

        client().admin().indices().refresh(new RefreshRequest(destIndex));
        SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
        for (SearchHit hit : sourceData.getHits()) {
            Map<String, Object> destDoc = getDestDoc(config, hit);
            Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
            @SuppressWarnings("unchecked")
            List<Map<String, Object>> importanceArray = (List<Map<String, Object>>) resultsObject.get("feature_importance");
            assertThat(importanceArray.stream().map(m -> m.get("feature_name").toString()).collect(Collectors.toSet()),
                everyItem(startsWith("f.")));
        }

        assertProgressComplete(jobId);
        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
        assertInferenceModelPersisted(jobId);
        assertModelStatePersisted(stateDocId());
    }

    private void initialize(String jobId) {
        initialize(jobId, false);
    }

    private void initialize(String jobId, boolean isDatastream) {
        this.jobId = jobId;
        this.sourceIndex = jobId + "_source_index";
        this.destIndex = sourceIndex + "_results";
        boolean analysisUsesExistingDestIndex = randomBoolean();
        createIndex(sourceIndex, isDatastream);
        if (analysisUsesExistingDestIndex) {
            createIndex(destIndex, false);
        }
    }

    private static void createIndex(String index, boolean isDatastream) {
        String mapping = "{\n" +
            "      \"properties\": {\n" +
            "        \"@timestamp\": {\n" +
            "          \"type\": \"date\"\n" +
            "        }," +
            "        \"" + BOOLEAN_FIELD + "\": {\n" +
            "          \"type\": \"boolean\"\n" +
            "        }," +
            "        \"" + NUMERICAL_FIELD + "\": {\n" +
            "          \"type\": \"double\"\n" +
            "        }," +
            "        \"" + DISCRETE_NUMERICAL_FIELD + "\": {\n" +
            "          \"type\": \"integer\"\n" +
            "        }," +
            "        \"" + TEXT_FIELD + "\": {\n" +
            "          \"type\": \"text\"\n" +
            "        }," +
            "        \"" + KEYWORD_FIELD + "\": {\n" +
            "          \"type\": \"keyword\"\n" +
            "        }," +
            "        \"" + NESTED_FIELD + "\": {\n" +
            "          \"type\": \"keyword\"\n" +
            "        }," +
            "        \"" + ALIAS_TO_KEYWORD_FIELD + "\": {\n" +
            "          \"type\": \"alias\",\n" +
            "          \"path\": \"" + KEYWORD_FIELD + "\"\n" +
            "        }," +
            "        \"" + ALIAS_TO_NESTED_FIELD + "\": {\n" +
            "          \"type\": \"alias\",\n" +
            "          \"path\": \"" + NESTED_FIELD + "\"\n" +
            "        }" +
            "      }\n" +
            "    }";
        if (isDatastream) {
            try {
                createDataStreamAndTemplate(index, mapping);
            } catch (IOException ex) {
                throw new ElasticsearchException(ex);
            }
        } else {
            client().admin().indices().prepareCreate(index)
                .addMapping("_doc", mapping, XContentType.JSON)
                .get();
        }
    }

    private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        for (int i = 0; i < numTrainingRows; i++) {
            List<Object> source = org.elasticsearch.common.collect.List.of(
                "@timestamp", "2020-12-12",
                BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
                NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
                DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
                TEXT_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
                KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
                NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
            IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
            bulkRequestBuilder.add(indexRequest);
        }
        for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
            List<Object> source = new ArrayList<>();
            if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
                source.addAll(org.elasticsearch.common.collect.List.of(BOOLEAN_FIELD,
                    BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
            }
            if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
                source.addAll(org.elasticsearch.common.collect.List.of(NUMERICAL_FIELD,
                    NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
            }
            if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
                source.addAll(
                    org.elasticsearch.common.collect.List.of(DISCRETE_NUMERICAL_FIELD,
                        DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
            }
            if (TEXT_FIELD.equals(dependentVariable) == false) {
                source.addAll(org.elasticsearch.common.collect.List.of(TEXT_FIELD,
                    KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
            }
            if (KEYWORD_FIELD.equals(dependentVariable) == false) {
                source.addAll(org.elasticsearch.common.collect.List.of(KEYWORD_FIELD,
                    KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
            }
            if (NESTED_FIELD.equals(dependentVariable) == false) {
                source.addAll(org.elasticsearch.common.collect.List.of(NESTED_FIELD,
                    KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
            }
            source.addAll(org.elasticsearch.common.collect.List.of("@timestamp", "2020-12-12"));
            IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
            bulkRequestBuilder.add(indexRequest);
        }
        BulkResponse bulkResponse = bulkRequestBuilder.get();
        if (bulkResponse.hasFailures()) {
            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
        }
    }

    private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
        GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
        assertThat(destDocGetResponse.isExists(), is(true));
        Map<String, Object> sourceDoc = hit.getSourceAsMap();
        Map<String, Object> destDoc = destDocGetResponse.getSource();
        for (String field : sourceDoc.keySet()) {
            assertThat(destDoc, hasKey(field));
            assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
        }
        return destDoc;
    }

    private String stateDocId() {
        return jobId + "_regression_state#1";
    }

    @Override
    boolean supportsInference() {
        return true;
    }
}