[7.x] [ML] adds new n_gram_encoding custom processor (#61578) (#61935)

* [ML] adds new n_gram_encoding custom processor (#61578)

This adds a new `n_gram_encoding` feature processor for analytics and inference.

The focus of this processor is simple n-gram encodings that allow:
 - multiple n-gram lengths in the range [1..5]
 - prefix, infix, and suffix encodings
Benjamin Trent 2020-09-04 08:36:50 -04:00 committed by GitHub
parent 7b021bf3fb
commit cec102a391
10 changed files with 1015 additions and 10 deletions
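Before the per-file diffs, a quick usage sketch (not part of the commit; the field name and parameter values are illustrative) showing how the new processor is assembled through the high-level REST client builder added below:

import java.util.Arrays;
import org.elasticsearch.client.ml.inference.preprocessing.NGram;

// Encode 1-grams and 2-grams over the first 10 characters of "text_field".
// Server side, emitted features are named "<feature_prefix>.<gram><position>",
// e.g. "f.10", "f.11", ..., "f.20", "f.21", ...
NGram nGram = NGram.builder("text_field")
    .setnGrams(Arrays.asList(1, 2))
    .setStart(0)
    .setLength(10)
    .setFeaturePrefix("f")
    .build();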


@@ -19,6 +19,7 @@
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
+import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
@@ -57,6 +58,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
FrequencyEncoding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME),
CustomWordEmbedding::fromXContent));
+namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME),
+NGram::fromXContent));
// Model
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));


@@ -0,0 +1,211 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
* PreProcessor for n-gram encoding a string
*/
public class NGram implements PreProcessor {
public static final String NAME = "n_gram_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix");
public static final ParseField NGRAMS = new ParseField("n_grams");
public static final ParseField START = new ParseField("start");
public static final ParseField LENGTH = new ParseField("length");
public static final ParseField CUSTOM = new ParseField("custom");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<NGram, Void> PARSER = new ConstructingObjectParser<NGram, Void>(
NAME,
true,
a -> new NGram((String)a[0],
(List<Integer>)a[1],
(Integer)a[2],
(Integer)a[3],
(Boolean)a[4],
(String)a[5]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), START);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX);
}
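// field and n_grams are required constructor arguments; start, length, custom and
// feature_prefix are optional and stay null on the client-side object when unset.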
public static NGram fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String field;
private final String featurePrefix;
private final List<Integer> nGrams;
private final Integer start;
private final Integer length;
private final Boolean custom;
NGram(String field, List<Integer> nGrams, Integer start, Integer length, Boolean custom, String featurePrefix) {
this.field = field;
this.featurePrefix = featurePrefix;
this.nGrams = nGrams;
this.start = start;
this.length = length;
this.custom = custom;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (field != null) {
builder.field(FIELD.getPreferredName(), field);
}
if (featurePrefix != null) {
builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix);
}
if (nGrams != null) {
builder.field(NGRAMS.getPreferredName(), nGrams);
}
if (start != null) {
builder.field(START.getPreferredName(), start);
}
if (length != null) {
builder.field(LENGTH.getPreferredName(), length);
}
if (custom != null) {
builder.field(CUSTOM.getPreferredName(), custom);
}
builder.endObject();
return builder;
}
public String getField() {
return field;
}
public String getFeaturePrefix() {
return featurePrefix;
}
public List<Integer> getnGrams() {
return nGrams;
}
public Integer getStart() {
return start;
}
public Integer getLength() {
return length;
}
public Boolean getCustom() {
return custom;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NGram nGram = (NGram) o;
return Objects.equals(field, nGram.field) &&
Objects.equals(featurePrefix, nGram.featurePrefix) &&
Objects.equals(nGrams, nGram.nGrams) &&
Objects.equals(start, nGram.start) &&
Objects.equals(length, nGram.length) &&
Objects.equals(custom, nGram.custom);
}
@Override
public int hashCode() {
return Objects.hash(field, featurePrefix, start, length, custom, nGrams);
}
public static Builder builder(String field) {
return new Builder(field);
}
public static class Builder {
private String field;
private String featurePrefix;
private List<Integer> nGrams;
private Integer start;
private Integer length;
private Boolean custom;
public Builder(String field) {
this.field = field;
}
public Builder setField(String field) {
this.field = field;
return this;
}
public Builder setFeaturePrefix(String featurePrefix) {
this.featurePrefix = featurePrefix;
return this;
}
public Builder setnGrams(List<Integer> nGrams) {
this.nGrams = nGrams;
return this;
}
public Builder setStart(Integer start) {
this.start = start;
return this;
}
public Builder setLength(Integer length) {
this.length = length;
return this;
}
public Builder setCustom(Boolean custom) {
this.custom = custom;
return this;
}
public NGram build() {
return new NGram(field, nGrams, start, length, custom, featurePrefix);
}
}
}
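Serialized via toXContent, a processor like the sketch above becomes a JSON object keyed by the parse fields declared in this class; within a trained model's preprocessors array it is wrapped under the processor name (a sketch, values illustrative):

{
  "n_gram_encoding": {
    "field": "text_field",
    "feature_prefix": "f",
    "n_grams": [1, 2],
    "start": 0,
    "length": 10
  }
}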


@@ -74,6 +74,7 @@ import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetec
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
@@ -704,7 +705,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-assertEquals(69, namedXContents.size());
+assertEquals(70, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -785,8 +786,9 @@
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
-assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
-assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
+assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
+assertThat(names,
+hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME, NGram.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
assertEquals(Integer.valueOf(4),


@@ -0,0 +1,55 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class NGramTests extends AbstractXContentTestCase<NGram> {
@Override
protected NGram doParseInstance(XContentParser parser) throws IOException {
return NGram.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected NGram createTestInstance() {
return createRandom();
}
public static NGram createRandom() {
return new NGram(randomAlphaOfLength(10),
IntStream.range(1, 5).limit(5).boxed().collect(Collectors.toList()),
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomIntBetween(1, 10),
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomAlphaOfLength(10));
}
}


@@ -161,6 +161,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
@@ -525,6 +526,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new),
new NamedWriteableRegistry.Entry(PreProcessor.class, CustomWordEmbedding.NAME.getPreferredName(), CustomWordEmbedding::new),
+new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(), NGram::new),
// ML - Inference models
new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new),
new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new),


@@ -9,6 +9,13 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
@@ -39,12 +46,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.Inferenc
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import java.util.ArrayList;
import java.util.List;
@@ -64,6 +65,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
+namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, NGram.NAME,
+(p, c) -> NGram.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
// PreProcessing Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
@@ -74,6 +77,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
+namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, NGram.NAME,
+(p, c) -> NGram.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
// Model Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
@@ -154,6 +159,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
FrequencyEncoding::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, CustomWordEmbedding.NAME.getPreferredName(),
CustomWordEmbedding::new));
+namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(),
+NGram::new));
// Model
namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));


@@ -0,0 +1,304 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.apache.lucene.util.RamUsageEstimator.sizeOf;
/**
* PreProcessor for n-gram encoding a string
*/
public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final int DEFAULT_START = 0;
private static final int DEFAULT_LENGTH = 50;
private static final int MAX_LENGTH = 100;
private static final int MIN_GRAM = 1;
private static final int MAX_GRAM = 5;
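// defaultPrefix below fills in the defaults: unset start/length yield "ngram_0_50";
// start=0 and length=5 yield "ngram_0_5" (see NGramTests#testInputOutputFields).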
private static String defaultPrefix(Integer start, Integer length) {
return "ngram_"
+ (start == null ? DEFAULT_START : start)
+ "_"
+ (length == null ? DEFAULT_LENGTH : length);
}
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NGram.class);
public static final ParseField NAME = new ParseField("n_gram_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix");
public static final ParseField NGRAMS = new ParseField("n_grams");
public static final ParseField START = new ParseField("start");
public static final ParseField LENGTH = new ParseField("length");
public static final ParseField CUSTOM = new ParseField("custom");
private static final ConstructingObjectParser<NGram, PreProcessorParseContext> STRICT_PARSER = createParser(false);
private static final ConstructingObjectParser<NGram, PreProcessorParseContext> LENIENT_PARSER = createParser(true);
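// The boolean passed to createParser is ConstructingObjectParser's ignoreUnknownFields
// flag: the lenient parser skips unrecognized fields (e.g. in documents written by a
// newer version), while the strict parser rejects them.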
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<NGram, PreProcessorParseContext> createParser(boolean lenient) {
ConstructingObjectParser<NGram, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
(a, c) -> new NGram((String)a[0],
(List<Integer>)a[1],
(Integer)a[2],
(Integer)a[3],
a[4] == null ? c.isCustomByDefault() : (Boolean)a[4],
(String)a[5]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), START);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH);
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX);
return parser;
}
public static NGram fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
}
public static NGram fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
}
private final String field;
private final String featurePrefix;
private final int[] nGrams;
private final int start;
private final int length;
private final boolean custom;
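// Convenience constructor used by the xContent parsers: fills in defaults for the
// optional settings and de-duplicates the requested gram sizes before delegating to
// the validating constructor below.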
NGram(String field,
List<Integer> nGrams,
Integer start,
Integer length,
Boolean custom,
String featurePrefix) {
this(field,
featurePrefix == null ? defaultPrefix(start, length) : featurePrefix,
Sets.newHashSet(nGrams).stream().mapToInt(Integer::intValue).toArray(),
start == null ? DEFAULT_START : start,
length == null ? DEFAULT_LENGTH : length,
custom != null && custom);
}
public NGram(String field, String featurePrefix, int[] nGrams, int start, int length, boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.featurePrefix = ExceptionsHelper.requireNonNull(featurePrefix, FEATURE_PREFIX);
this.nGrams = ExceptionsHelper.requireNonNull(nGrams, NGRAMS);
if (Arrays.stream(this.nGrams).anyMatch(i -> i < MIN_GRAM || i > MAX_GRAM)) {
throw ExceptionsHelper.badRequestException(
"[{}] is invalid [{}]; minimum supported value is [{}]; maximum supported value is [{}]",
NGRAMS.getPreferredName(),
Arrays.stream(nGrams).mapToObj(String::valueOf).collect(Collectors.joining(", ")),
MIN_GRAM,
MAX_GRAM);
}
this.start = start;
if (start < 0 && length + start > 0) {
throw ExceptionsHelper.badRequestException(
"if [start] is negative, [length] + [start] must be less than 0");
}
this.length = length;
if (length <= 0) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", LENGTH.getPreferredName());
}
if (length > MAX_LENGTH) {
throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
}
this.custom = custom;
}
public NGram(StreamInput in) throws IOException {
this.field = in.readString();
this.featurePrefix = in.readString();
this.nGrams = in.readVIntArray();
this.start = in.readInt();
this.length = in.readVInt();
this.custom = in.readBoolean();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(field);
out.writeString(featurePrefix);
out.writeVIntArray(nGrams);
out.writeInt(start);
out.writeVInt(length);
out.writeBoolean(custom);
}
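// Wire-format note: start is written with writeInt because it may be negative
// (suffix encoding), while length is strictly positive and uses writeVInt.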
@Override
public String toString() {
return Strings.toString(this);
}
@Override
public List<String> inputFields() {
return Collections.singletonList(field);
}
@Override
public List<String> outputFields() {
return allPossibleNGramOutputFeatureNames();
}
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
if (value == null) {
return;
}
final String stringValue = value.toString();
// String is too small for the starting point
if (start > stringValue.length() || stringValue.length() + start < 0) {
return;
}
final int startPos = start < 0 ? (stringValue.length() + start) : start;
final int len = Math.min(startPos + length, stringValue.length());
for (int i = 0; i < len; i++) {
for (int nGram : nGrams) {
if (startPos + i + nGram > len) {
break;
}
fields.put(nGramFeature(nGram, i), stringValue.substring(startPos + i, startPos + i + nGram));
}
}
}
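// Worked example (mirrors NGramTests#testProcessNGramPrefix): with featurePrefix "f",
// nGrams [1, 4], start 0 and length 5, processing "this is the value" adds
// f.10="t", f.11="h", f.12="i", f.13="s", f.14=" ", f.40="this" and f.41="his ".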
@Override
public Map<String, String> reverseLookup() {
return outputFields().stream().collect(Collectors.toMap(Function.identity(), ignored -> field));
}
@Override
public String getOutputFieldType(String outputField) {
return TextFieldMapper.CONTENT_TYPE;
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += sizeOf(field);
size += sizeOf(featurePrefix);
size += sizeOf(nGrams);
return size;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix);
builder.field(NGRAMS.getPreferredName(), nGrams);
builder.field(START.getPreferredName(), start);
builder.field(LENGTH.getPreferredName(), length);
builder.field(CUSTOM.getPreferredName(), custom);
builder.endObject();
return builder;
}
public String getField() {
return field;
}
public String getFeaturePrefix() {
return featurePrefix;
}
public int[] getnGrams() {
return nGrams;
}
public int getStart() {
return start;
}
public int getLength() {
return length;
}
@Override
public boolean isCustom() {
return custom;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NGram nGram = (NGram) o;
return start == nGram.start &&
length == nGram.length &&
custom == nGram.custom &&
Objects.equals(field, nGram.field) &&
Objects.equals(featurePrefix, nGram.featurePrefix) &&
Arrays.equals(nGrams, nGram.nGrams);
}
@Override
public int hashCode() {
int result = Objects.hash(field, featurePrefix, start, length, custom);
result = 31 * result + Arrays.hashCode(nGrams);
return result;
}
private String nGramFeature(int nGram, int pos) {
return featurePrefix
+ "."
+ nGram
+ pos;
}
private List<String> allPossibleNGramOutputFeatureNames() {
int totalNgrams = 0;
for (int nGram : nGrams) {
// Guard: a gram size larger than the window would make this negative and
// break the ArrayList sizing below.
totalNgrams += Math.max(0, length - (nGram - 1));
}
List<String> ngramOutputs = new ArrayList<>(totalNgrams);
for (int nGram : nGrams) {
IntFunction<String> func = i -> nGramFeature(nGram, i);
IntStream.range(0, (length - (nGram - 1))).mapToObj(func).forEach(ngramOutputs::add);
}
return ngramOutputs;
}
}
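A quick sanity check on the output-name arithmetic above, where each gram size g contributes length - (g - 1) names (a sketch, not part of the commit; names illustrative):

// For nGrams [1, 4] and length 5: (5 - 0) + (5 - 3) = 7 output features,
// matching the names asserted in NGramTests#testInputOutputFields.
NGram encoding = new NGram("text", "f", new int[]{1, 4}, 0, 5, false);
assert encoding.outputFields().size() == 7;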


@@ -36,7 +36,7 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
List<String> inputFields();
/**
-* @return The resulting output fields
+* @return The resulting output fields. It is imperative that the order is consistent between calls.
*/
List<String> outputFields();


@@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.hamcrest.Matcher;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.contains;
public class NGramTests extends PreProcessingTests<NGram> {
@Override
protected NGram doParseInstance(XContentParser parser) throws IOException {
return lenient ?
NGram.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
NGram.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
}
@Override
protected NGram createTestInstance() {
return createRandom();
}
public static NGram createRandom() {
return createRandom(randomBoolean() ? randomBoolean() : null);
}
public static NGram createRandom(Boolean isCustom) {
return new NGram(
randomAlphaOfLength(10),
IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()),
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomIntBetween(1, 10),
isCustom,
randomBoolean() ? null : randomAlphaOfLength(10));
}
@Override
protected Writeable.Reader<NGram> instanceReader() {
return NGram::new;
}
public void testProcessNGramPrefix() {
String field = "text";
String fieldValue = "this is the value";
NGram encoding = new NGram(field, "f", new int[]{1, 4}, 0, 5, false);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
Map<String, Matcher<? super Object>> matchers = new HashMap<>();
matchers.put("f.10", equalTo("t"));
matchers.put("f.11", equalTo("h"));
matchers.put("f.12", equalTo("i"));
matchers.put("f.13", equalTo("s"));
matchers.put("f.14", equalTo(" "));
matchers.put("f.40", equalTo("this"));
matchers.put("f.41", equalTo("his "));
testProcess(encoding, fieldValues, matchers);
}
public void testProcessNGramSuffix() {
String field = "text";
String fieldValue = "this is the value";
NGram encoding = new NGram(field, "f", new int[]{1, 3}, -3, 3, false);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
Map<String, Matcher<? super Object>> matchers = new HashMap<>();
matchers.put("f.10", equalTo("l"));
matchers.put("f.11", equalTo("u"));
matchers.put("f.12", equalTo("e"));
matchers.put("f.30", equalTo("lue"));
matchers.put("f.31", is(nullValue()));
testProcess(encoding, fieldValues, matchers);
}
public void testProcessNGramInfix() {
String field = "text";
String fieldValue = "this is the value";
NGram encoding = new NGram(field, "f", new int[]{1, 3}, 3, 3, false);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
Map<String, Matcher<? super Object>> matchers = new HashMap<>();
matchers.put("f.10", equalTo("s"));
matchers.put("f.11", equalTo(" "));
matchers.put("f.12", equalTo("i"));
matchers.put("f.30", equalTo("s i"));
matchers.put("f.31", is(nullValue()));
testProcess(encoding, fieldValues, matchers);
}
public void testProcessNGramLengthOverrun() {
String field = "text";
String fieldValue = "this is the value";
NGram encoding = new NGram(field, "f", new int[]{1, 3}, 12, 10, false);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
Map<String, Matcher<? super Object>> matchers = new HashMap<>();
matchers.put("f.10", equalTo("v"));
matchers.put("f.11", equalTo("a"));
matchers.put("f.12", equalTo("l"));
matchers.put("f.13", equalTo("u"));
matchers.put("f.14", equalTo("e"));
matchers.put("f.30", equalTo("val"));
matchers.put("f.31", equalTo("alu"));
matchers.put("f.32", equalTo("lue"));
testProcess(encoding, fieldValues, matchers);
}
public void testInputOutputFields() {
String field = randomAlphaOfLength(10);
NGram encoding = new NGram(field, "f", new int[]{1, 4}, 0, 5, false);
assertThat(encoding.inputFields(), containsInAnyOrder(field));
assertThat(encoding.outputFields(),
contains("f.10", "f.11","f.12","f.13","f.14","f.40", "f.41"));
encoding = new NGram(field, Arrays.asList(1, 4), 0, 5, false, null);
assertThat(encoding.inputFields(), containsInAnyOrder(field));
assertThat(encoding.outputFields(),
contains(
"ngram_0_5.10",
"ngram_0_5.11",
"ngram_0_5.12",
"ngram_0_5.13",
"ngram_0_5.14",
"ngram_0_5.40",
"ngram_0_5.41"));
}
}


@@ -0,0 +1,277 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String BOOLEAN_FIELD = "boolean-field";
private static final String NUMERICAL_FIELD = "numerical-field";
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
private static final String TEXT_FIELD = "text-field";
private static final String KEYWORD_FIELD = "keyword-field";
private static final String NESTED_FIELD = "outer-field.inner-field";
private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
private static final List<Boolean> BOOLEAN_FIELD_VALUES = org.elasticsearch.common.collect.List.of(false, true);
private static final List<Double> NUMERICAL_FIELD_VALUES = org.elasticsearch.common.collect.List.of(1.0, 2.0);
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = org.elasticsearch.common.collect.List.of(10, 20);
private static final List<String> KEYWORD_FIELD_VALUES = org.elasticsearch.common.collect.List.of("cat", "dog");
private String jobId;
private String sourceIndex;
private String destIndex;
@Before
public void setupLogging() {
client().admin().cluster()
.prepareUpdateSettings()
.setTransientSettings(Settings.builder()
.put("logger.org.elasticsearch.xpack.ml.dataframe", "DEBUG")
.put("logger.org.elasticsearch.xpack.core.ml.inference", "DEBUG"))
.get();
}
@After
public void cleanup() {
cleanUp();
client().admin().cluster()
.prepareUpdateSettings()
.setTransientSettings(Settings.builder()
.putNull("logger.org.elasticsearch.xpack.ml.dataframe")
.putNull("logger.org.elasticsearch.xpack.core.ml.inference"))
.get();
}
@Override
protected NamedXContentRegistry xContentRegistry() {
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(entries);
}
public void testNGramCustomFeature() throws Exception {
initialize("test_ngram_feature_processor");
String predictedClassField = NUMERICAL_FIELD + "_prediction";
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(jobId)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()), null))
.setDest(new DataFrameAnalyticsDest(destIndex, null))
.setAnalysis(new Regression(NUMERICAL_FIELD,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(6).build(),
null,
null,
42L,
null,
null,
Collections.singletonList(new NGram(TEXT_FIELD, "f", new int[]{1, 2}, 0, 2, true))))
.setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
.build();
putAnalytics(config);
assertIsStopped(jobId);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
client().admin().indices().refresh(new RefreshRequest(destIndex));
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
assertThat(importanceArray.stream().map(m -> m.get("feature_name").toString()).collect(Collectors.toSet()),
everyItem(startsWith("f.")));
}
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertInferenceModelPersisted(jobId);
assertModelStatePersisted(stateDocId());
}
private void initialize(String jobId) {
initialize(jobId, false);
}
private void initialize(String jobId, boolean isDatastream) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";
this.destIndex = sourceIndex + "_results";
boolean analysisUsesExistingDestIndex = randomBoolean();
createIndex(sourceIndex, isDatastream);
if (analysisUsesExistingDestIndex) {
createIndex(destIndex, false);
}
}
private static void createIndex(String index, boolean isDatastream) {
String mapping = "{\n" +
" \"properties\": {\n" +
" \"@timestamp\": {\n" +
" \"type\": \"date\"\n" +
" }," +
" \""+ BOOLEAN_FIELD + "\": {\n" +
" \"type\": \"boolean\"\n" +
" }," +
" \""+ NUMERICAL_FIELD + "\": {\n" +
" \"type\": \"double\"\n" +
" }," +
" \""+ DISCRETE_NUMERICAL_FIELD + "\": {\n" +
" \"type\": \"integer\"\n" +
" }," +
" \""+ TEXT_FIELD + "\": {\n" +
" \"type\": \"text\"\n" +
" }," +
" \""+ KEYWORD_FIELD + "\": {\n" +
" \"type\": \"keyword\"\n" +
" }," +
" \""+ NESTED_FIELD + "\": {\n" +
" \"type\": \"keyword\"\n" +
" }," +
" \""+ ALIAS_TO_KEYWORD_FIELD + "\": {\n" +
" \"type\": \"alias\",\n" +
" \"path\": \"" + KEYWORD_FIELD + "\"\n" +
" }," +
" \""+ ALIAS_TO_NESTED_FIELD + "\": {\n" +
" \"type\": \"alias\",\n" +
" \"path\": \"" + NESTED_FIELD + "\"\n" +
" }" +
" }\n" +
" }";
if (isDatastream) {
try {
createDataStreamAndTemplate(index, mapping);
} catch (IOException ex) {
throw new ElasticsearchException(ex);
}
} else {
client().admin().indices().prepareCreate(index)
.addMapping("_doc", mapping, XContentType.JSON)
.get();
}
}
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numTrainingRows; i++) {
List<Object> source = org.elasticsearch.common.collect.List.of(
"@timestamp", "2020-12-12",
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
TEXT_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
bulkRequestBuilder.add(indexRequest);
}
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
List<Object> source = new ArrayList<>();
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
source.addAll(org.elasticsearch.common.collect.List.of(BOOLEAN_FIELD,
BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
}
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(org.elasticsearch.common.collect.List.of(NUMERICAL_FIELD,
NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
}
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(
org.elasticsearch.common.collect.List.of(DISCRETE_NUMERICAL_FIELD,
DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
}
if (TEXT_FIELD.equals(dependentVariable) == false) {
source.addAll(org.elasticsearch.common.collect.List.of(TEXT_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(org.elasticsearch.common.collect.List.of(KEYWORD_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (NESTED_FIELD.equals(dependentVariable) == false) {
source.addAll(org.elasticsearch.common.collect.List.of(NESTED_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
source.addAll(org.elasticsearch.common.collect.List.of("@timestamp", "2020-12-12"));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc, hasKey(field));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
return destDoc;
}
private String stateDocId() {
return jobId + "_regression_state#1";
}
@Override
boolean supportsInference() {
return true;
}
}