[ML] adds new feature_processors field for data frame analytics (#60528) (#61148)

`feature_processors` allow users to create custom features from
individual document fields.

These `feature_processors` are the same objects as a trained model's `pre_processors`.

They are passed to the native process, which then appends them to the
`pre_processor` array in the inference model.

closes https://github.com/elastic/elasticsearch/issues/59327
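
For context, a minimal sketch of what the new constructor arity looks like from Java, using the `Classification` and `OneHotEncoding` signatures from the diff below; the field names and map values here are hypothetical examples, not part of this change:

    // Hypothetical one-hot feature processor: encodes the "animal" field into an "animal_cat" column.
    PreProcessor oneHot = new OneHotEncoding(
        "animal",                                      // input document field (example)
        Collections.singletonMap("cat", "animal_cat"), // hot_map: field value -> output column
        true);                                         // custom: a user-defined feature processor

    Classification analysis = new Classification(
        "label",                              // dependentVariable (example)
        BoostedTreeParams.builder().build(),  // default boosted tree params
        null,                                 // predictionFieldName
        null,                                 // classAssignmentObjective
        null,                                 // numTopClasses
        null,                                 // trainingPercent
        null,                                 // randomizeSeed
        Collections.singletonList(oneHot));   // the new feature_processors argument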
Benjamin Trent 2020-08-14 10:32:20 -04:00 committed by GitHub
parent d1b60269f4
commit 8f302282f4
43 changed files with 1591 additions and 195 deletions

View File

@@ -15,10 +15,14 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.mapper.FieldAliasMapper;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

 import java.io.IOException;
 import java.util.Arrays;
@@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
     public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
+    public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

     private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
@@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
      */
     public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

+    @SuppressWarnings("unchecked")
     private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
         ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
@@ -70,7 +76,8 @@ public class Classification implements DataFrameAnalysis {
                 (ClassAssignmentObjective) a[8],
                 (Integer) a[9],
                 (Double) a[10],
-                (Long) a[11]));
+                (Long) a[11],
+                (List<PreProcessor>) a[12]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@@ -78,6 +85,12 @@ public class Classification implements DataFrameAnalysis {
         parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
         parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
         parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
+        parser.declareNamedObjects(optionalConstructorArg(),
+            (p, c, n) -> lenient ?
+                p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
+            (classification) -> {/*TODO should we throw if this is not set?*/},
+            FEATURE_PROCESSORS);
         return parser;
     }
@@ -119,6 +132,7 @@ public class Classification implements DataFrameAnalysis {
     private final int numTopClasses;
     private final double trainingPercent;
     private final long randomizeSeed;
+    private final List<PreProcessor> featureProcessors;

     public Classification(String dependentVariable,
                           BoostedTreeParams boostedTreeParams,
@@ -126,7 +140,8 @@ public class Classification implements DataFrameAnalysis {
                           @Nullable ClassAssignmentObjective classAssignmentObjective,
                           @Nullable Integer numTopClasses,
                           @Nullable Double trainingPercent,
-                          @Nullable Long randomizeSeed) {
+                          @Nullable Long randomizeSeed,
+                          @Nullable List<PreProcessor> featureProcessors) {
         if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
             throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
         }
@@ -141,10 +156,11 @@ public class Classification implements DataFrameAnalysis {
         this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
         this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
         this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
+        this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
     }

     public Classification(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
     }

     public Classification(StreamInput in) throws IOException {
@@ -163,6 +179,11 @@ public class Classification implements DataFrameAnalysis {
         } else {
             randomizeSeed = Randomness.get().nextLong();
         }
+        if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+            featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
+        } else {
+            featureProcessors = Collections.emptyList();
+        }
     }

     public String getDependentVariable() {
@@ -193,6 +214,10 @@ public class Classification implements DataFrameAnalysis {
         return randomizeSeed;
     }

+    public List<PreProcessor> getFeatureProcessors() {
+        return featureProcessors;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -211,6 +236,9 @@ public class Classification implements DataFrameAnalysis {
         if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
             out.writeOptionalLong(randomizeSeed);
         }
+        if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+            out.writeNamedWriteableList(featureProcessors);
+        }
     }

     @Override
@@ -229,6 +257,9 @@ public class Classification implements DataFrameAnalysis {
         if (version.onOrAfter(Version.V_7_6_0)) {
             builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
         }
+        if (featureProcessors.isEmpty() == false) {
+            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
+        }
         builder.endObject();
         return builder;
     }
@@ -249,6 +280,10 @@ public class Classification implements DataFrameAnalysis {
         }
         params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
         params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
+        if (featureProcessors.isEmpty() == false) {
+            params.put(FEATURE_PROCESSORS.getPreferredName(),
+                featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
+        }
         return params;
     }
@@ -390,6 +425,7 @@ public class Classification implements DataFrameAnalysis {
             && Objects.equals(predictionFieldName, that.predictionFieldName)
             && Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
             && Objects.equals(numTopClasses, that.numTopClasses)
+            && Objects.equals(featureProcessors, that.featureProcessors)
             && trainingPercent == that.trainingPercent
             && randomizeSeed == that.randomizeSeed;
     }
@@ -397,7 +433,7 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
-            numTopClasses, trainingPercent, randomizeSeed);
+            numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
     }

     public enum ClassAssignmentObjective {

View File

@@ -15,9 +15,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

 import java.io.IOException;
 import java.util.Arrays;
@@ -28,6 +32,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.stream.Collectors;

 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
     public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
+    public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");

     private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";

     private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
     private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);

+    @SuppressWarnings("unchecked")
     private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
         ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
@@ -59,7 +66,8 @@ public class Regression implements DataFrameAnalysis {
                 (Double) a[8],
                 (Long) a[9],
                 (LossFunction) a[10],
-                (Double) a[11]));
+                (Double) a[11],
+                (List<PreProcessor>) a[12]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@@ -67,6 +75,12 @@ public class Regression implements DataFrameAnalysis {
         parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
         parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
         parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
+        parser.declareNamedObjects(optionalConstructorArg(),
+            (p, c, n) -> lenient ?
+                p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
+            (regression) -> {/*TODO should we throw if this is not set?*/},
+            FEATURE_PROCESSORS);
         return parser;
     }
@@ -90,6 +104,7 @@ public class Regression implements DataFrameAnalysis {
     private final long randomizeSeed;
     private final LossFunction lossFunction;
     private final Double lossFunctionParameter;
+    private final List<PreProcessor> featureProcessors;

     public Regression(String dependentVariable,
                       BoostedTreeParams boostedTreeParams,
@@ -97,7 +112,8 @@ public class Regression implements DataFrameAnalysis {
                       @Nullable Double trainingPercent,
                       @Nullable Long randomizeSeed,
                       @Nullable LossFunction lossFunction,
-                      @Nullable Double lossFunctionParameter) {
+                      @Nullable Double lossFunctionParameter,
+                      @Nullable List<PreProcessor> featureProcessors) {
         if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
             throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
         }
@@ -112,10 +128,11 @@ public class Regression implements DataFrameAnalysis {
             throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
         }
         this.lossFunctionParameter = lossFunctionParameter;
+        this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
     }

     public Regression(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
     }

     public Regression(StreamInput in) throws IOException {
@@ -136,6 +153,11 @@ public class Regression implements DataFrameAnalysis {
             lossFunction = LossFunction.MSE;
             lossFunctionParameter = null;
         }
+        if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+            featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
+        } else {
+            featureProcessors = Collections.emptyList();
+        }
     }

     public String getDependentVariable() {
@@ -166,6 +188,10 @@ public class Regression implements DataFrameAnalysis {
         return lossFunctionParameter;
     }

+    public List<PreProcessor> getFeatureProcessors() {
+        return featureProcessors;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -184,6 +210,9 @@ public class Regression implements DataFrameAnalysis {
             out.writeEnum(lossFunction);
             out.writeOptionalDouble(lossFunctionParameter);
         }
+        if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+            out.writeNamedWriteableList(featureProcessors);
+        }
     }

     @Override
@@ -204,6 +233,9 @@ public class Regression implements DataFrameAnalysis {
         if (lossFunctionParameter != null) {
             builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
         }
+        if (featureProcessors.isEmpty() == false) {
+            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
+        }
         builder.endObject();
         return builder;
     }
@@ -221,6 +253,10 @@ public class Regression implements DataFrameAnalysis {
         if (lossFunctionParameter != null) {
             params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
         }
+        if (featureProcessors.isEmpty() == false) {
+            params.put(FEATURE_PROCESSORS.getPreferredName(),
+                featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
+        }
         return params;
     }
@@ -304,13 +340,14 @@ public class Regression implements DataFrameAnalysis {
             && trainingPercent == that.trainingPercent
             && randomizeSeed == that.randomizeSeed
             && lossFunction == that.lossFunction
+            && Objects.equals(featureProcessors, that.featureProcessors)
             && Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
     }

     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
-            lossFunctionParameter);
+            lossFunctionParameter, featureProcessors);
     }

     public enum LossFunction {

View File

@@ -57,23 +57,23 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         // PreProcessing Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
-            OneHotEncoding::fromXContentLenient));
+            (p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
-            TargetMeanEncoding::fromXContentLenient));
+            (p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
-            FrequencyEncoding::fromXContentLenient));
+            (p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
-            CustomWordEmbedding::fromXContentLenient));
+            (p, c) -> CustomWordEmbedding.fromXContentLenient(p)));

         // PreProcessing Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
-            OneHotEncoding::fromXContentStrict));
+            (p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
-            TargetMeanEncoding::fromXContentStrict));
+            (p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
-            FrequencyEncoding::fromXContentStrict));
+            (p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
-            CustomWordEmbedding::fromXContentStrict));
+            (p, c) -> CustomWordEmbedding.fromXContentStrict(p)));

         // Model Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));

View File

@@ -56,8 +56,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
             TRAINED_MODEL);
         parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
             (p, c, n) -> ignoreUnknownFields ?
-                p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
-                p.namedObject(StrictlyParsedPreProcessor.class, n, null),
+                p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
             (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
             PREPROCESSORS);
         return parser;

View File

@@ -50,15 +50,15 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
     public static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
     public static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");

-    public static final ConstructingObjectParser<CustomWordEmbedding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<CustomWordEmbedding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<CustomWordEmbedding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<CustomWordEmbedding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
+            (a, c) -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
         parser.declareField(ConstructingObjectParser.constructorArg(),
             (p, c) -> {
@@ -123,11 +123,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
     }

     public static CustomWordEmbedding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+        return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
     }

     public static CustomWordEmbedding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+        return LENIENT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
     }

     private static final int CONCAT_LAYER_SIZE = 80;
@@ -256,6 +256,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
         return false;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return "dense_vector";
+    }
+
     @Override
     public long ramBytesUsed() {
         long size = SHALLOW_SIZE;

View File

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
@@ -37,15 +38,18 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
     public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
     public static final ParseField CUSTOM = new ParseField("custom");

-    public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<FrequencyEncoding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<FrequencyEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
+            (a, c) -> new FrequencyEncoding((String)a[0],
+                (String)a[1],
+                (Map<String, Double>)a[2],
+                a[3] == null ? c.isCustomByDefault() : (Boolean)a[3]));
         parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
         parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
         parser.declareObject(ConstructingObjectParser.constructorArg(),
@@ -55,12 +59,12 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
         return parser;
     }

-    public static FrequencyEncoding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+    public static FrequencyEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

-    public static FrequencyEncoding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+    public static FrequencyEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

     private final String field;
@@ -117,6 +121,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
         return custom;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return NumberFieldMapper.NumberType.DOUBLE.typeName();
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();

View File

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
@@ -36,27 +37,29 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
     public static final ParseField HOT_MAP = new ParseField("hot_map");
     public static final ParseField CUSTOM = new ParseField("custom");

-    public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<OneHotEncoding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<OneHotEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
+            (a, c) -> new OneHotEncoding((String)a[0],
+                (Map<String, String>)a[1],
+                a[2] == null ? c.isCustomByDefault() : (Boolean)a[2]));
         parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
         parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
         return parser;
     }

-    public static OneHotEncoding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+    public static OneHotEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

-    public static OneHotEncoding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+    public static OneHotEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

     private final String field;
@@ -103,6 +106,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
         return custom;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return NumberFieldMapper.NumberType.INTEGER.typeName();
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();
@@ -124,8 +132,9 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
         if (value == null) {
             return;
         }
+        final String stringValue = value.toString();
         hotMap.forEach((val, col) -> {
-            int encoding = value.toString().equals(val) ? 1 : 0;
+            int encoding = stringValue.equals(val) ? 1 : 0;
             fields.put(col, encoding);
         });
     }
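
As a usage sketch of the hoisted `stringValue` above (hypothetical field and map values, assuming `process(Map<String, Object> fields)` as in the rest of this class):

    Map<String, Object> fields = new HashMap<>();
    fields.put("animal", "cat");
    Map<String, String> hotMap = new HashMap<>();
    hotMap.put("cat", "animal_cat");
    hotMap.put("dog", "animal_dog");
    // Each hot_map entry becomes a 0/1 integer column; value.toString() is now computed once.
    new OneHotEncoding("animal", hotMap, true).process(fields);
    // fields now also contains animal_cat=1 and animal_dog=0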

View File

@@ -18,6 +18,18 @@ import java.util.Map;
  */
 public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {

+    class PreProcessorParseContext {
+        public static final PreProcessorParseContext DEFAULT = new PreProcessorParseContext(false);
+
+        final boolean defaultIsCustomValue;
+
+        public PreProcessorParseContext(boolean defaultIsCustomValue) {
+            this.defaultIsCustomValue = defaultIsCustomValue;
+        }
+
+        public boolean isCustomByDefault() {
+            return defaultIsCustomValue;
+        }
+    }
+
     /**
      * The expected input fields
      */
@@ -48,4 +60,6 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
      */
     boolean isCustom();
+
+    String getOutputFieldType(String outputField);
 }
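
The effect of the new context, in short: trained model `pre_processors` keep parsing with `custom` defaulting to false (via `PreProcessorParseContext.DEFAULT`), while user-supplied `feature_processors` parse with `custom` defaulting to true. A minimal sketch of that difference, using only the code shown in this change:

    // feature_processors are parsed with a context where custom defaults to true
    PreProcessor.PreProcessorParseContext featureContext = new PreProcessor.PreProcessorParseContext(true);
    assert featureContext.isCustomByDefault();

    // trained model pre_processors keep the old behavior via DEFAULT (custom defaults to false)
    assert PreProcessor.PreProcessorParseContext.DEFAULT.isCustomByDefault() == false;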

View File

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
@@ -37,15 +38,19 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
     public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
     public static final ParseField CUSTOM = new ParseField("custom");

-    public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
-    public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> LENIENT_PARSER = createParser(true);

     @SuppressWarnings("unchecked")
-    private static ConstructingObjectParser<TargetMeanEncoding, Void> createParser(boolean lenient) {
-        ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
+    private static ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> createParser(boolean lenient) {
+        ConstructingObjectParser<TargetMeanEncoding, PreProcessorParseContext> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
+            (a, c) -> new TargetMeanEncoding((String)a[0],
+                (String)a[1],
+                (Map<String, Double>)a[2],
+                (Double)a[3],
+                a[4] == null ? c.isCustomByDefault() : (Boolean)a[4]));
         parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
         parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
         parser.declareObject(ConstructingObjectParser.constructorArg(),
@@ -56,12 +61,12 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
         return parser;
     }

-    public static TargetMeanEncoding fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+    public static TargetMeanEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

-    public static TargetMeanEncoding fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+    public static TargetMeanEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context);
     }

     private final String field;
@@ -128,6 +133,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
         return custom;
     }

+    @Override
+    public String getOutputFieldType(String outputField) {
+        return NumberFieldMapper.NumberType.DOUBLE.typeName();
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();

View File

@@ -41,7 +41,7 @@ public class InferenceDefinition {
             (p, c, n) -> p.namedObject(InferenceModel.class, n, null),
             TRAINED_MODEL);
         PARSER.declareNamedObjects(InferenceDefinition.Builder::setPreProcessors,
-            (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, null),
+            (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
             (trainedModelDefBuilder) -> {},
             PREPROCESSORS);
     }

View File

@@ -327,12 +327,14 @@ public final class ReservedFieldNames {
         Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(),
         Regression.PREDICTION_FIELD_NAME.getPreferredName(),
         Regression.TRAINING_PERCENT.getPreferredName(),
+        Regression.FEATURE_PROCESSORS.getPreferredName(),
         Classification.NAME.getPreferredName(),
         Classification.DEPENDENT_VARIABLE.getPreferredName(),
         Classification.PREDICTION_FIELD_NAME.getPreferredName(),
         Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(),
         Classification.NUM_TOP_CLASSES.getPreferredName(),
         Classification.TRAINING_PERCENT.getPreferredName(),
+        Classification.FEATURE_PROCESSORS.getPreferredName(),
         BoostedTreeParams.LAMBDA.getPreferredName(),
         BoostedTreeParams.GAMMA.getPreferredName(),
         BoostedTreeParams.ETA.getPreferredName(),

View File

@@ -34,6 +34,9 @@
         "feature_bag_fraction" : {
           "type" : "double"
         },
+        "feature_processors": {
+          "enabled": false
+        },
         "gamma" : {
           "type" : "double"
         },
@@ -84,6 +87,9 @@
         "feature_bag_fraction" : {
           "type" : "double"
         },
+        "feature_processors": {
+          "enabled": false
+        },
         "gamma" : {
           "type" : "double"
         },

View File

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Respon
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;

 import java.util.ArrayList;
 import java.util.Collections;
@@ -28,6 +29,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

@@ -36,6 +38,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }

View File

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.junit.Before;

 import java.util.ArrayList;
@@ -44,6 +45,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

@@ -52,6 +54,7 @@ public class PutDataFrameAnalyticsActionRequestTests extends AbstractSerializing
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }

View File

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;

 import java.util.ArrayList;
 import java.util.Collections;
@@ -25,6 +26,7 @@ public class PutDataFrameAnalyticsActionResponseTests extends AbstractWireSerial
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

View File

@@ -41,6 +41,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import org.junit.Before;

@@ -78,6 +79,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }

@@ -86,6 +88,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }

@@ -144,14 +147,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 bwcRegression.getTrainingPercent(),
                 42L,
                 bwcRegression.getLossFunction(),
-                bwcRegression.getLossFunctionParameter());
+                bwcRegression.getLossFunctionParameter(),
+                bwcRegression.getFeatureProcessors());
             testAnalysis = new Regression(testRegression.getDependentVariable(),
                 testRegression.getBoostedTreeParams(),
                 testRegression.getPredictionFieldName(),
                 testRegression.getTrainingPercent(),
                 42L,
                 testRegression.getLossFunction(),
-                testRegression.getLossFunctionParameter());
+                testRegression.getLossFunctionParameter(),
+                bwcRegression.getFeatureProcessors());
         } else {
             Classification testClassification = (Classification)testInstance.getAnalysis();
             Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
@@ -161,14 +166,16 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 bwcClassification.getClassAssignmentObjective(),
                 bwcClassification.getNumTopClasses(),
                 bwcClassification.getTrainingPercent(),
-                42L);
+                42L,
+                bwcClassification.getFeatureProcessors());
             testAnalysis = new Classification(testClassification.getDependentVariable(),
                 testClassification.getBoostedTreeParams(),
                 testClassification.getPredictionFieldName(),
                 testClassification.getClassAssignmentObjective(),
                 testClassification.getNumTopClasses(),
                 testClassification.getTrainingPercent(),
-                42L);
+                42L,
+                testClassification.getFeatureProcessors());
         }
         super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
             .setAnalysis(bwcAnalysis)

View File

@ -8,25 +8,41 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
@ -55,6 +71,21 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
return createRandom();
}
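// feature_processors entries are polymorphic PreProcessor objects, so the XContent and wire
// round trips in this test need the ML inference named parsers and writeables registered.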
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
public static Classification createRandom() {
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
@ -65,7 +96,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed,
randomBoolean() ?
null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
OneHotEncodingTests.createRandom(true),
TargetMeanEncodingTests.createRandom(true)))
.limit(randomIntBetween(0, 5))
.collect(Collectors.toList()));
}
public static Classification mutateForVersion(Classification instance, Version version) {
@ -75,7 +113,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null,
instance.getNumTopClasses(),
instance.getTrainingPercent(),
instance.getRandomizeSeed(),
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
}
@Override
@ -91,14 +130,16 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
bwcSerializedObject.getClassAssignmentObjective(),
bwcSerializedObject.getNumTopClasses(),
bwcSerializedObject.getTrainingPercent(),
42L,
bwcSerializedObject.getFeatureProcessors());
Classification newInstance = new Classification(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
testInstance.getClassAssignmentObjective(),
testInstance.getNumTopClasses(),
testInstance.getTrainingPercent(),
42L,
testInstance.getFeatureProcessors());
super.assertOnBWCObject(newBwc, newInstance, version);
}
@ -107,87 +148,138 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
return Classification::new;
}
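// feature_processors use the same named objects as trained model pre_processors; parsing
// them through the analysis config is expected to mark every processor as custom.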
public void testDeserialization() throws IOException {
String toDeserialize = "{\n" +
" \"dependent_variable\": \"FlightDelayMin\",\n" +
" \"feature_processors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"OriginWeather\",\n" +
" \"hot_map\": {\n" +
" \"sunny_col\": \"Sunny\",\n" +
" \"clear_col\": \"Clear\",\n" +
" \"rainy_col\": \"Rain\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"DestWeather\",\n" +
" \"hot_map\": {\n" +
" \"dest_sunny_col\": \"Sunny\",\n" +
" \"dest_clear_col\": \"Clear\",\n" +
" \"dest_rainy_col\": \"Rain\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"OriginWeather\",\n" +
" \"feature_name\": \"mean\",\n" +
" \"frequency_map\": {\n" +
" \"Sunny\": 0.8,\n" +
" \"Rain\": 0.2\n" +
" }\n" +
" }\n" +
" }\n" +
" ]\n" +
" }" +
"";
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(toDeserialize),
XContentType.JSON)) {
Classification parsed = Classification.fromXContent(parser, false);
assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
assertThat(preProcessor.isCustom(), is(true));
}
}
}
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
public void testGetPredictionFieldName() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
assertThat(classification.getPredictionFieldName(), equalTo("result"));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testClassAssignmentObjective() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
// class_assignment_objective == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
}
public void testGetNumTopClasses() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(7));
// Boundary condition: num_top_classes == 0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(0));
// Boundary condition: num_top_classes == 1000
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(1000));
// num_top_classes == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(2));
}
public void testGetTrainingPercent() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
assertThat(classification.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
assertThat(classification.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
assertThat(classification.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
@ -233,6 +325,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
null,
null,
50.0,
null,
null).getParams(fieldInfo),
equalTo(
org.elasticsearch.common.collect.Map.of(
View File
@ -8,18 +8,35 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
@ -45,6 +62,21 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
return createRandom();
}
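// As in ClassificationTests, the polymorphic feature_processors require the ML inference
// named parsers and writeables for the XContent and stream round trips below.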
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
public static Regression createRandom() {
return createRandom(BoostedTreeParamsTests.createRandom());
}
@ -57,7 +89,14 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter,
randomBoolean() ?
null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
OneHotEncodingTests.createRandom(true),
TargetMeanEncodingTests.createRandom(true)))
.limit(randomIntBetween(0, 5))
.collect(Collectors.toList()));
}
public static Regression mutateForVersion(Regression instance, Version version) {
@ -67,7 +106,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
instance.getTrainingPercent(),
instance.getRandomizeSeed(),
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunction() : null,
version.onOrAfter(Version.V_7_8_0) ? instance.getLossFunctionParameter() : null,
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
}
@Override
@ -83,14 +123,16 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
bwcSerializedObject.getTrainingPercent(),
42L,
bwcSerializedObject.getLossFunction(),
bwcSerializedObject.getLossFunctionParameter(),
bwcSerializedObject.getFeatureProcessors());
Regression newInstance = new Regression(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
testInstance.getTrainingPercent(),
42L,
testInstance.getLossFunction(),
testInstance.getLossFunctionParameter(),
testInstance.getFeatureProcessors());
super.assertOnBWCObject(newBwc, newInstance, version);
}
@ -104,56 +146,122 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
return Regression::new;
}
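// Mirrors the classification test: processors parsed from feature_processors should all
// come back flagged as custom.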
public void testDeserialization() throws IOException {
String toDeserialize = "{\n" +
" \"dependent_variable\": \"FlightDelayMin\",\n" +
" \"feature_processors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"OriginWeather\",\n" +
" \"hot_map\": {\n" +
" \"sunny_col\": \"Sunny\",\n" +
" \"clear_col\": \"Clear\",\n" +
" \"rainy_col\": \"Rain\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"DestWeather\",\n" +
" \"hot_map\": {\n" +
" \"dest_sunny_col\": \"Sunny\",\n" +
" \"dest_clear_col\": \"Clear\",\n" +
" \"dest_rainy_col\": \"Rain\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"OriginWeather\",\n" +
" \"feature_name\": \"mean\",\n" +
" \"frequency_map\": {\n" +
" \"Sunny\": 0.8,\n" +
" \"Rain\": 0.2\n" +
" }\n" +
" }\n" +
" }\n" +
" ]\n" +
" }" +
"";
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(toDeserialize),
XContentType.JSON)) {
Regression parsed = Regression.fromXContent(parser, false);
assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
assertThat(preProcessor.isCustom(), is(true));
}
}
}
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenLossFunctionParameterIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null));
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
}
public void testConstructor_GivenLossFunctionParameterIsNegative() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null));
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
}
public void testGetPredictionFieldName() {
Regression regression = new Regression(
"foo",
BOOSTED_TREE_PARAMS,
"result",
50.0,
randomLong(),
Regression.LossFunction.MSE,
1.0,
null);
assertThat(regression.getPredictionFieldName(), equalTo("result"));
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null);
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testGetTrainingPercent() {
Regression regression = new Regression("foo",
BOOSTED_TREE_PARAMS,
"result",
50.0,
randomLong(),
Regression.LossFunction.MSE,
1.0,
null);
assertThat(regression.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null);
assertThat(regression.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
@ -165,6 +273,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
100.0,
0L,
Regression.LossFunction.MSE,
null,
null);
Map<String, Object> params = regression.getParams(null);
@ -182,7 +291,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
Map<String, Object> params = regression.getParams(null);
int expectedParamsCount = 4
+ (regression.getLossFunctionParameter() == null ? 0 : 1)
+ (regression.getFeatureProcessors().isEmpty() ? 0 : 1);
assertThat(params.size(), equalTo(expectedParamsCount));
assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable()));
assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName()));
View File
@ -24,7 +24,9 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
@Override
protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
return lenient ?
FrequencyEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
FrequencyEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
}
@Override
@ -33,6 +35,10 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
}
public static FrequencyEncoding createRandom() {
return createRandom(randomBoolean() ? null : randomBoolean());
}
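// Overload that lets feature processor tests pin the custom flag instead of randomizing it.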
public static FrequencyEncoding createRandom(Boolean isCustom) {
int valuesSize = randomIntBetween(1, 10);
Map<String, Double> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
@ -41,7 +47,7 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
return new FrequencyEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
isCustom);
}
@Override
View File
@ -24,7 +24,9 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
@Override
protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
return lenient ?
OneHotEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
OneHotEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
}
@Override
@ -33,6 +35,10 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
}
public static OneHotEncoding createRandom() {
return createRandom(randomBoolean() ? randomBoolean() : null);
}
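// Overload that lets feature processor tests pin the custom flag instead of randomizing it.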
public static OneHotEncoding createRandom(Boolean isCustom) {
int valuesSize = randomIntBetween(1, 10);
Map<String, String> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
@ -40,7 +46,7 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
}
return new OneHotEncoding(randomAlphaOfLength(10),
valueMap,
isCustom);
}
@Override
View File
@ -24,7 +24,9 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
@Override
protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
return lenient ?
TargetMeanEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
TargetMeanEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
}
@Override
@ -32,7 +34,12 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
return createRandom();
}
public static TargetMeanEncoding createRandom() {
return createRandom(randomBoolean() ? randomBoolean() : null);
}
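// Overload that lets feature processor tests pin the custom flag instead of randomizing it.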
public static TargetMeanEncoding createRandom(Boolean isCustom) {
int valuesSize = randomIntBetween(1, 10);
Map<String, Double> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
@ -42,7 +49,7 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
randomAlphaOfLength(10),
valueMap,
randomDoubleBetween(0.0, 1.0, false),
isCustom);
}
@Override
View File
@ -21,22 +21,30 @@ import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.junit.After;
import org.junit.Before;
@ -108,6 +116,15 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.get();
}
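// The registry is needed to parse the persisted trained model definition (and its
// pre_processors) via ensureParsedDefinition in testWithCustomFeatureProcessors.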
@Override
protected NamedXContentRegistry xContentRegistry() {
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(entries);
}
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set");
String predictedClassField = KEYWORD_FIELD + "_prediction";
@ -121,6 +138,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
putAnalytics(config);
@ -176,6 +194,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
putAnalytics(config);
@ -268,6 +287,76 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}
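// Runs a classification job with a user-supplied one_hot_encoding feature processor and
// verifies it ends up, flagged as custom, in the persisted inference model's pre_processor list.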
public void testWithCustomFeatureProcessors() throws Exception {
initialize("classification_with_custom_feature_processors");
String predictedClassField = KEYWORD_FIELD + "_prediction";
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
DataFrameAnalyticsConfig config =
buildAnalytics(jobId, sourceIndex, destIndex, null,
new Classification(
KEYWORD_FIELD,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
null,
null,
null,
null,
null,
Arrays.asList(
new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true)
)));
putAnalytics(config);
assertIsStopped(jobId);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
client().admin().indices().refresh(new RefreshRequest(destIndex));
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
assertThat(importanceArray, hasSize(greaterThan(0)));
}
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
"Starting analytics on node",
"Started analytics",
expectedDestIndexAuditMessage(),
"Started reindexing to destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Started loading data",
"Started analyzing",
"Started writing results",
"Finished analysis");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
assertThat(response.getResources().results().size(), equalTo(1));
TrainedModelConfig modelConfig = response.getResources().results().get(0);
modelConfig.ensureParsedDefinition(xContentRegistry());
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
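// The supplied processor is expected first in the list and is the only one flagged custom;
// encodings the platform adds during training are not.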
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
assertThat(preProcessor.isCustom(), equalTo(i == 0));
}
}
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
String dependentVariable,
List<T> dependentVariableValues,
@ -283,7 +372,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -352,7 +441,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer"); "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() {
ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException e = expectThrows(
ElasticsearchStatusException.class, ElasticsearchStatusException.class,
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
@ -360,7 +449,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];")); assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() {
ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException e = expectThrows(
ElasticsearchStatusException.class, ElasticsearchStatusException.class,
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
@ -549,7 +638,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null));
putAnalytics(firstJob);
String secondJobId = "classification_two_jobs_with_same_randomize_seed_2";
@ -557,7 +646,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null));
putAnalytics(secondJob);
View File
@ -104,6 +104,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
100.0,
null,
null,
null,
null))
.buildForExplain();
@ -122,6 +123,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
50.0,
null,
null,
null,
null))
.buildForExplain();
@ -149,6 +151,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
100.0,
null,
null,
null,
null))
.buildForExplain();
View File
@ -14,24 +14,36 @@ import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.junit.After;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -65,6 +77,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
cleanUp();
}
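// Needed to parse the persisted trained model definition in testWithCustomFeatureProcessors.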
@Override
protected NamedXContentRegistry xContentRegistry() {
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(entries);
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59413")
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("regression_single_numeric_feature_and_mixed_data_set");
@ -79,6 +100,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -192,7 +214,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -319,7 +341,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null));
putAnalytics(firstJob);
String secondJobId = "regression_two_jobs_with_same_randomize_seed_2";
@ -327,7 +349,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null));
putAnalytics(secondJob);
@ -388,7 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -415,6 +437,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -511,6 +534,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
90.0, 90.0,
null, null,
null, null,
null,
null);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(jobId)
@@ -566,6 +590,73 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Finished analysis");
}
public void testWithCustomFeatureProcessors() throws Exception {
initialize("regression_with_custom_feature_processors");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
indexData(sourceIndex, 300, 50);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
new Regression(
DEPENDENT_VARIABLE_FIELD,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
null,
null,
null,
null,
null,
Arrays.asList(
new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD,
Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true)
))
);
putAnalytics(config);
assertIsStopped(jobId);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
// for debugging
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
assertThat(resultsObject.containsKey(predictedClassField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
}
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
"Starting analytics on node",
"Started analytics",
"Creating destination index [" + destIndex + "]",
"Started reindexing to destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Started loading data",
"Started analyzing",
"Started writing results",
"Finished analysis");
GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
assertThat(response.getResources().results().size(), equalTo(1));
TrainedModelConfig modelConfig = response.getResources().results().get(0);
modelConfig.ensureParsedDefinition(xContentRegistry());
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
assertThat(preProcessor.isCustom(), equalTo(i == 0));
}
}
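For reference, a minimal sketch (illustrative, not part of this commit) of what the one_hot_encoding processor configured in the test above does to a document's field map. The field name and value below are hypothetical stand-ins for DISCRETE_NUMERICAL_FEATURE_FIELD and its first value; the write-back behavior mirrors how PreProcessor.process is used by ProcessedField later in this commit.

    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;

    import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

    public class OneHotFeatureProcessorSketch {
        public static void main(String[] args) {
            // "discrete_numerical-field" and "10" are hypothetical stand-ins for the
            // constants used in the test; custom=true marks it as a feature_processor.
            OneHotEncoding encoding = new OneHotEncoding(
                "discrete_numerical-field",
                Collections.singletonMap("10", "tenner"),
                true);

            Map<String, Object> fields = new HashMap<>();
            fields.put("discrete_numerical-field", "10");
            // The processor writes the encoded output column into the same map,
            // so "tenner" should now be present alongside the raw field.
            encoding.process(fields);
            System.out.println(fields.get("tenner"));
        }
    }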
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";


@@ -71,7 +71,7 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
analyticsConfig,
new DataFrameAnalyticsAuditor(client(), "test-node"),
(ex) -> { throw new ElasticsearchException(ex); },
new ExtractedFields(extractedFieldList, Collections.emptyList(), Collections.emptyMap())
);
//Accuracy for size is not tested here


@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -171,9 +172,9 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
blockingCall(
actionListener -> configProvider.put(initialConfig, emptyMap(), actionListener), configHolder, exceptionHolder);
assertNoException(exceptionHolder);
assertThat(configHolder.get(), is(notNullValue()));
assertThat(configHolder.get(), is(equalTo(initialConfig)));
}
{ // Update that changes description
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
@@ -188,7 +189,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
actionListener -> configProvider.update(configUpdate, emptyMap(), ClusterState.EMPTY_STATE, actionListener),
updatedConfigHolder,
exceptionHolder);
assertNoException(exceptionHolder);
assertThat(updatedConfigHolder.get(), is(notNullValue()));
assertThat(
updatedConfigHolder.get(),
@@ -196,7 +197,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
new DataFrameAnalyticsConfig.Builder(initialConfig)
.setDescription("description-1")
.build())));
}
{ // Update that changes model memory limit
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
@@ -212,6 +212,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
updatedConfigHolder,
exceptionHolder);
assertNoException(exceptionHolder);
assertThat(updatedConfigHolder.get(), is(notNullValue()));
assertThat(
updatedConfigHolder.get(),
@@ -220,7 +221,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
.setDescription("description-1")
.setModelMemoryLimit(new ByteSizeValue(1024))
.build())));
}
{ // Noop update
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
@@ -233,6 +233,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
updatedConfigHolder,
exceptionHolder);
assertNoException(exceptionHolder);
assertThat(updatedConfigHolder.get(), is(notNullValue()));
assertThat(
updatedConfigHolder.get(),
@@ -241,7 +242,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
.setDescription("description-1")
.setModelMemoryLimit(new ByteSizeValue(1024))
.build())));
}
{ // Update that changes both description and model memory limit
AtomicReference<DataFrameAnalyticsConfig> updatedConfigHolder = new AtomicReference<>();
@@ -258,6 +258,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
updatedConfigHolder,
exceptionHolder);
assertNoException(exceptionHolder);
assertThat(updatedConfigHolder.get(), is(notNullValue()));
assertThat(
updatedConfigHolder.get(),
@@ -266,7 +267,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
.setDescription("description-2")
.setModelMemoryLimit(new ByteSizeValue(2048))
.build())));
}
{ // Update that applies security headers
Map<String, String> securityHeaders = Collections.singletonMap("_xpack_security_authentication", "dummy");
@@ -281,6 +281,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
updatedConfigHolder,
exceptionHolder);
assertNoException(exceptionHolder);
assertThat(updatedConfigHolder.get(), is(notNullValue()));
assertThat(
updatedConfigHolder.get(),
@@ -290,7 +291,6 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
.setModelMemoryLimit(new ByteSizeValue(2048))
.setHeaders(securityHeaders)
.build())));
}
}
@@ -371,6 +371,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
}


@@ -28,7 +28,9 @@ public class TimeBasedExtractedFields extends ExtractedFields {
private final ExtractedField timeField;
public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
super(allFields,
Collections.emptyList(),
Collections.emptyMap());
if (!allFields.contains(timeField)) {
throw new IllegalArgumentException("timeField should also be contained in allFields");
}


@@ -28,15 +28,18 @@ import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.ProcessedField;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
@@ -46,6 +49,7 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * An implementation that extracts data from elasticsearch using search and scroll on a client.
@@ -67,10 +71,29 @@ public class DataFrameDataExtractor {
private boolean hasNext;
private boolean searchHasShardFailure;
private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
// These are fields that are sent directly to the analytics process
// They are not passed through a feature_processor
private final String[] organicFeatures;
// These are the output field names for the feature_processors
private final String[] processedFeatures;
private final Map<String, ExtractedField> extractedFieldsByName;
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
this.client = Objects.requireNonNull(client);
this.context = Objects.requireNonNull(context);
Set<String> processedFieldInputs = context.extractedFields.getProcessedFieldInputs();
this.organicFeatures = context.extractedFields.getAllFields()
.stream()
.map(ExtractedField::getName)
.filter(f -> processedFieldInputs.contains(f) == false)
.toArray(String[]::new);
this.processedFeatures = context.extractedFields.getProcessedFields()
.stream()
.map(ProcessedField::getOutputFieldNames)
.flatMap(List::stream)
.toArray(String[]::new);
this.extractedFieldsByName = new LinkedHashMap<>();
context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
hasNext = true;
searchHasShardFailure = false;
this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
@@ -188,26 +211,78 @@ public class DataFrameDataExtractor {
return rows;
}
private String extractNonProcessedValues(SearchHit hit, String organicFeature) {
ExtractedField field = extractedFieldsByName.get(organicFeature);
Object[] values = field.value(hit);
if (values.length == 1 && isValidValue(values[0])) {
return Objects.toString(values[0]);
}
if (values.length == 0 && context.supportsRowsWithMissingValues) {
// if values is empty then it means it's a missing value
return NULL_VALUE;
}
// we are here if we have a missing value but the analysis does not support those
// or the value type is not supported (e.g. arrays, etc.)
return null;
}
private String[] extractProcessedValue(ProcessedField processedField, SearchHit hit) {
Object[] values = processedField.value(hit, extractedFieldsByName::get);
if (values.length == 0 && context.supportsRowsWithMissingValues == false) {
return null;
}
final String[] extractedValue = new String[processedField.getOutputFieldNames().size()];
for (int i = 0; i < processedField.getOutputFieldNames().size(); i++) {
extractedValue[i] = NULL_VALUE;
}
// if values is empty then it means it's a missing value
if (values.length == 0) {
return extractedValue;
}
if (values.length != processedField.getOutputFieldNames().size()) {
throw ExceptionsHelper.badRequestException(
"field_processor [{}] output size expected to be [{}], instead it was [{}]",
processedField.getProcessorName(),
processedField.getOutputFieldNames().size(),
values.length);
}
for (int i = 0; i < processedField.getOutputFieldNames().size(); ++i) {
Object value = values[i];
if (value == null && context.supportsRowsWithMissingValues) {
continue;
}
if (isValidValue(value) == false) {
// we are here if we have a missing value but the analysis does not support those
// or the value type is not supported (e.g. arrays, etc.)
return null;
}
extractedValue[i] = Objects.toString(value);
}
return extractedValue;
}
private Row createRow(SearchHit hit) {
String[] extractedValues = new String[organicFeatures.length + processedFeatures.length];
int i = 0;
for (String organicFeature : organicFeatures) {
String extractedValue = extractNonProcessedValues(hit, organicFeature);
if (extractedValue == null) {
return new Row(null, hit, true);
}
extractedValues[i++] = extractedValue;
}
for (ProcessedField processedField : context.extractedFields.getProcessedFields()) {
String[] processedValues = extractProcessedValue(processedField, hit);
if (processedValues == null) {
return new Row(null, hit, true);
}
for (String processedValue : processedValues) {
extractedValues[i++] = processedValue;
}
}
boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
return new Row(extractedValues, hit, isTraining);
}
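To make the ordering contract above concrete, here is a simplified, standalone rendering of how a row is laid out (plain Java, not the Elasticsearch classes): organic feature values come first in field order, followed by every feature_processor output; a single unextractable value voids the whole row instead.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class RowLayoutSketch {
        public static void main(String[] args) {
            List<String> organicValues = Arrays.asList("42.0", "blue"); // fields sent as-is
            List<String> processedValues = Arrays.asList("1", "0");     // feature_processor outputs

            // Same ordering contract as createRow: organic values first, then every
            // processed output, in declaration order.
            List<String> row = new ArrayList<>(organicValues);
            row.addAll(processedValues);
            System.out.println(row); // [42.0, blue, 1, 0]
        }
    }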
@@ -241,7 +316,7 @@ public class DataFrameDataExtractor {
}
public List<String> getFieldNames() {
return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList());
}
public ExtractedFields getExtractedFields() {
@@ -253,12 +328,12 @@ public class DataFrameDataExtractor {
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
long rows = searchResponse.getHits().getTotalHits().value;
LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
}
public void collectDataSummaryAsync(ActionListener<DataSummary> dataSummaryActionListener) {
SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
final int numberOfFields = organicFeatures.length + processedFeatures.length;
ClientHelper.executeWithHeadersAsync(context.headers,
ClientHelper.ML_ORIGIN,
@@ -298,7 +373,11 @@ public class DataFrameDataExtractor {
}
public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
}
private static boolean isValidValue(Object value) {
return value instanceof Number || value instanceof String;
}
public static class DataSummary {


@@ -13,27 +13,33 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NameResolver;
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.ProcessedField;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
@@ -60,7 +66,9 @@ public class ExtractedFieldsDetector {
private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
ExtractedFieldsDetector(DataFrameAnalyticsConfig config,
int docValueFieldsLimit,
FieldCapabilitiesResponse fieldCapabilitiesResponse,
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.config = Objects.requireNonNull(config);
this.docValueFieldsLimit = docValueFieldsLimit;
@@ -69,23 +77,39 @@ public class ExtractedFieldsDetector {
}
public Tuple<ExtractedFields, List<FieldSelection>> detect() {
List<ProcessedField> processedFields = extractFeatureProcessors()
.stream()
.map(ProcessedField::new)
.collect(Collectors.toList());
TreeSet<FieldSelection> fieldSelection = new TreeSet<>(Comparator.comparing(FieldSelection::getName));
Set<String> fields = getIncludedFields(fieldSelection,
processedFields.stream()
.map(ProcessedField::getInputFieldNames)
.flatMap(List::stream)
.collect(Collectors.toSet()));
checkFieldsHaveCompatibleTypes(fields);
checkRequiredFields(fields);
checkFieldsWithCardinalityLimit();
ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection, processedFields);
addIncludedFields(extractedFields, fieldSelection);
checkOutputFeatureUniqueness(processedFields, fields);
return Tuple.tuple(extractedFields, Collections.unmodifiableList(new ArrayList<>(fieldSelection)));
}
private Set<String> getIncludedFields(Set<FieldSelection> fieldSelection, Set<String> requiredFieldsForProcessors) {
Set<String> fields = new TreeSet<>(fieldCapabilitiesResponse.get().keySet());
validateFieldsRequireForProcessors(requiredFieldsForProcessors);
fields.removeAll(IGNORE_FIELDS);
removeFieldsUnderResultsField(fields);
removeObjects(fields);
applySourceFiltering(fields);
if (fields.containsAll(requiredFieldsForProcessors) == false) {
throw ExceptionsHelper.badRequestException(
"fields {} required by field_processors are not included in source filtering.",
Sets.difference(requiredFieldsForProcessors, fields));
}
FetchSourceContext analyzedFields = config.getAnalyzedFields();
// If the user has not explicitly included fields we'll include all compatible fields
@@ -93,20 +117,63 @@ public class ExtractedFieldsDetector {
removeFieldsWithIncompatibleTypes(fields, fieldSelection);
}
includeAndExcludeFields(fields, fieldSelection);
if (fields.containsAll(requiredFieldsForProcessors) == false) {
throw ExceptionsHelper.badRequestException(
"fields {} required by field_processors are not included in the analyzed_fields.",
Sets.difference(requiredFieldsForProcessors, fields));
}
return fields;
}
private void validateFieldsRequireForProcessors(Set<String> processorFields) {
Set<String> fieldsForProcessor = new HashSet<>(processorFields);
removeFieldsUnderResultsField(fieldsForProcessor);
if (fieldsForProcessor.size() < processorFields.size()) {
throw ExceptionsHelper.badRequestException("fields contained in results field [{}] cannot be used in a feature_processor",
config.getDest().getResultsField());
}
removeObjects(fieldsForProcessor);
if (fieldsForProcessor.size() < processorFields.size()) {
throw ExceptionsHelper.badRequestException("fields for feature_processors must not be objects");
}
fieldsForProcessor.removeAll(IGNORE_FIELDS);
if (fieldsForProcessor.size() < processorFields.size()) {
throw ExceptionsHelper.badRequestException("the following fields cannot be used in feature_processors {}", IGNORE_FIELDS);
}
List<String> fieldsMissingInMapping = processorFields.stream()
.filter(f -> fieldCapabilitiesResponse.get().containsKey(f) == false)
.collect(Collectors.toList());
if (fieldsMissingInMapping.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"the fields {} were not found in the field capabilities of the source indices [{}]. "
+ "Fields must exist and be mapped to be used in feature_processors.",
fieldsMissingInMapping,
Strings.arrayToCommaDelimitedString(config.getSource().getIndex()));
}
List<String> processedRequiredFields = config.getAnalysis()
.getRequiredFields()
.stream()
.map(RequiredField::getName)
.filter(processorFields::contains)
.collect(Collectors.toList());
if (processedRequiredFields.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"required analysis fields {} cannot be used in a feature_processor",
processedRequiredFields);
}
}
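The two "not included" errors in getIncludedFields report their payload via a set difference between what the processors need and what survived filtering. A tiny sketch of that computation, using the same Sets utility this file imports (the field names are made up):

    import java.util.Set;
    import java.util.TreeSet;

    import org.elasticsearch.common.util.set.Sets;

    public class MissingProcessorFieldsSketch {
        public static void main(String[] args) {
            // Hypothetical field names; the real sets come from the processors'
            // declared inputs and the include/exclude filtering above.
            Set<String> requiredByProcessors = new TreeSet<>(Set.of("field_a", "field_b"));
            Set<String> includedFields = new TreeSet<>(Set.of("field_a"));
            // The difference is exactly what the bad_request message interpolates.
            System.out.println(Sets.difference(requiredByProcessors, includedFields)); // [field_b]
        }
    }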
private void removeFieldsUnderResultsField(Set<String> fields) {
final String resultsFieldPrefix = config.getDest().getResultsField() + ".";
Iterator<String> fieldsIterator = fields.iterator();
while (fieldsIterator.hasNext()) {
String field = fieldsIterator.next();
if (field.startsWith(resultsFieldPrefix)) {
fieldsIterator.remove();
}
}
fields.removeIf(field -> field.startsWith(resultsFieldPrefix));
}
private void removeObjects(Set<String> fields) {
@@ -287,9 +354,23 @@ public class ExtractedFieldsDetector {
}
}
private List<PreProcessor> extractFeatureProcessors() {
if (config.getAnalysis() instanceof Classification) {
return ((Classification)config.getAnalysis()).getFeatureProcessors();
} else if (config.getAnalysis() instanceof Regression) {
return ((Regression)config.getAnalysis()).getFeatureProcessors();
}
return Collections.emptyList();
}
private ExtractedFields detectExtractedFields(Set<String> fields,
Set<FieldSelection> fieldSelection,
List<ProcessedField> processedFields) {
ExtractedFields extractedFields = ExtractedFields.build(fields,
Collections.emptySet(),
fieldCapabilitiesResponse,
cardinalitiesForFieldsWithConstraints,
processedFields);
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
if (preferSource) {
@@ -304,10 +385,15 @@ public class ExtractedFieldsDetector {
return extractedFields;
}
private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields,
boolean preferSource,
Set<FieldSelection> fieldSelection) {
Set<String> requiredFields = config.getAnalysis()
.getRequiredFields()
.stream()
.map(RequiredField::getName)
.collect(Collectors.toSet());
Set<String> processorInputFields = extractedFields.getProcessedFieldInputs();
Map<String, ExtractedField> nameOrParentToField = new LinkedHashMap<>();
for (ExtractedField currentField : extractedFields.getAllFields()) {
String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName();
@@ -315,15 +401,37 @@ public class ExtractedFieldsDetector {
if (existingField != null) {
ExtractedField parent = currentField.isMultiField() ? existingField : currentField;
ExtractedField multiField = currentField.isMultiField() ? currentField : existingField;
// If required fields contains parent or multifield and the processor input fields reference the other, that is an error
// we should not allow processing of data that is required.
if ((requiredFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName()))
|| (requiredFields.contains(multiField.getName()) && processorInputFields.contains(parent.getName()))) {
throw ExceptionsHelper.badRequestException(
"feature_processors cannot be applied to required fields for analysis; multi-field [{}] parent [{}]",
multiField.getName(),
parent.getName());
}
// If processor input fields have BOTH, we need to keep both.
if (processorInputFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) {
throw ExceptionsHelper.badRequestException(
"feature_processors refer to both multi-field [{}] and parent [{}]. Please only refer to one or the other",
multiField.getName(),
parent.getName());
}
nameOrParentToField.put(nameOrParent,
chooseMultiFieldOrParent(preferSource, requiredFields, processorInputFields, parent, multiField, fieldSelection));
}
}
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()),
extractedFields.getProcessedFields(),
cardinalitiesForFieldsWithConstraints);
}
private ExtractedField chooseMultiFieldOrParent(boolean preferSource,
Set<String> requiredFields,
Set<String> processorInputFields,
ExtractedField parent,
ExtractedField multiField,
Set<FieldSelection> fieldSelection) {
// Check requirements first
if (requiredFields.contains(parent.getName())) {
addExcludedField(multiField.getName(), "[" + parent.getName() + "] is required instead", fieldSelection);
@@ -333,6 +441,19 @@ public class ExtractedFieldsDetector {
addExcludedField(parent.getName(), "[" + multiField.getName() + "] is required instead", fieldSelection);
return multiField;
}
// Choose the one required by our processors
if (processorInputFields.contains(parent.getName())) {
addExcludedField(multiField.getName(),
"[" + parent.getName() + "] is referenced by feature_processors instead",
fieldSelection);
return parent;
}
if (processorInputFields.contains(multiField.getName())) {
addExcludedField(parent.getName(),
"[" + multiField.getName() + "] is referenced by feature_processors instead",
fieldSelection);
return multiField;
}
// If both are multi-fields it means there are several. In this case parent is the previous multi-field
// we selected. We'll just keep that.
@@ -370,7 +491,9 @@ public class ExtractedFieldsDetector {
for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
}
return new ExtractedFields(adjusted,
extractedFields.getProcessedFields(),
cardinalitiesForFieldsWithConstraints);
}
private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
@@ -387,13 +510,15 @@ public class ExtractedFieldsDetector {
adjusted.add(field);
}
}
return new ExtractedFields(adjusted,
extractedFields.getProcessedFields(),
cardinalitiesForFieldsWithConstraints);
}
private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {
Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
.collect(Collectors.toSet());
Set<String> categoricalFields = getCategoricalInputFields(extractedFields, config.getAnalysis());
for (ExtractedField includedField : extractedFields.getAllFields()) {
FieldSelection.FeatureType featureType = categoricalFields.contains(includedField.getName()) ?
FieldSelection.FeatureType.CATEGORICAL : FieldSelection.FeatureType.NUMERICAL;
@@ -402,7 +527,38 @@ public class ExtractedFieldsDetector {
}
}
static void checkOutputFeatureUniqueness(List<ProcessedField> processedFields, Set<String> selectedFields) {
Set<String> processInputs = processedFields.stream()
.map(ProcessedField::getInputFieldNames)
.flatMap(List::stream)
.collect(Collectors.toSet());
// All analysis fields that we include that are NOT processed
// This indicates that they are sent as is
Set<String> organicFields = Sets.difference(selectedFields, processInputs);
Set<String> processedFeatures = new HashSet<>();
Set<String> duplicatedFields = new HashSet<>();
for (ProcessedField processedField : processedFields) {
for (String output : processedField.getOutputFieldNames()) {
if (processedFeatures.add(output) == false) {
duplicatedFields.add(output);
}
}
}
if (duplicatedFields.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"feature_processors must define unique output field names; duplicate fields {}",
duplicatedFields);
}
Set<String> duplicateOrganicAndProcessed = Sets.intersection(organicFields, processedFeatures);
if (duplicateOrganicAndProcessed.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"feature_processors output fields must not include non-processed analysis fields; duplicate fields {}",
duplicateOrganicAndProcessed);
}
}
static Set<String> getCategoricalInputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
return extractedFields.getAllFields().stream()
.filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName())
.containsAll(extractedField.getTypes()))
@@ -410,6 +566,25 @@ public class ExtractedFieldsDetector {
.collect(Collectors.toSet());
}
static Set<String> getCategoricalOutputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) {
Set<String> processInputFields = extractedFields.getProcessedFieldInputs();
Set<String> categoricalFields = extractedFields.getAllFields().stream()
.filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName())
.containsAll(extractedField.getTypes()))
.map(ExtractedField::getName)
.filter(name -> processInputFields.contains(name) == false)
.collect(Collectors.toSet());
extractedFields.getProcessedFields().forEach(processedField ->
processedField.getOutputFieldNames().forEach(outputField -> {
if (analysis.getAllowedCategoricalTypes(outputField).containsAll(processedField.getOutputFieldType(outputField))) {
categoricalFields.add(outputField);
}
})
);
return Collections.unmodifiableSet(categoricalFields);
}
private static boolean isBoolean(Set<String> types) {
return types.size() == 1 && types.contains(BooleanFieldMapper.CONTENT_TYPE);
}


@@ -178,7 +178,7 @@ public class AnalyticsProcessManager {
AnalyticsProcess<AnalyticsResult> process = processContext.process.get();
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
try {
writeHeaderRecord(dataExtractor, process, task);
writeDataRows(dataExtractor, process, task);
process.writeEndOfDataMessage();
process.flushStream();
@@ -268,8 +268,11 @@ public class AnalyticsProcessManager {
}
}
private void writeHeaderRecord(DataFrameDataExtractor dataExtractor,
AnalyticsProcess<AnalyticsResult> process,
DataFrameAnalyticsTask task) throws IOException {
List<String> fieldNames = dataExtractor.getFieldNames();
LOGGER.debug(() -> new ParameterizedMessage("[{}] header row fields {}", task.getParams().getId(), fieldNames));
// We add 2 extra fields, both named dot:
// - the document hash
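Based on the comment above, a sketch of the header record's shape. The two trailing "." control columns (document hash and control field) are assumptions drawn from that comment; the real method body is elided in this hunk.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class HeaderRecordSketch {
        static String[] headerRecord(List<String> fieldNames) {
            List<String> record = new ArrayList<>(fieldNames);
            record.add("."); // document hash column (assumption from the comment above)
            record.add("."); // control message column (assumption from the comment above)
            return record.toArray(new String[0]);
        }

        public static void main(String[] args) {
            System.out.println(Arrays.toString(headerRecord(Arrays.asList("f1", "f2", "tenner"))));
        }
    }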


@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
@@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
@@ -34,6 +36,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -191,8 +194,21 @@ public class ChunkedTrainedModelPersister {
return latch;
}
private long customProcessorSize() {
List<PreProcessor> preProcessors = new ArrayList<>();
if (analytics.getAnalysis() instanceof Classification) {
preProcessors = ((Classification) analytics.getAnalysis()).getFeatureProcessors();
} else if (analytics.getAnalysis() instanceof Regression) {
preProcessors = ((Regression) analytics.getAnalysis()).getFeatureProcessors();
}
return preProcessors.stream().mapToLong(PreProcessor::ramBytesUsed).sum()
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * preProcessors.size();
}
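customProcessorSize() follows the usual Lucene accounting pattern: sum each object's own estimate, plus one object reference per list slot. A standalone sketch of the same pattern over arbitrary Accountables (illustrative, not part of the commit):

    import java.util.List;

    import org.apache.lucene.util.Accountable;
    import org.apache.lucene.util.RamUsageEstimator;

    public class RamAccountingSketch {
        // Same accounting shape as customProcessorSize(): each element's own estimate
        // plus one object reference per list slot.
        static long sizeOf(List<? extends Accountable> items) {
            return items.stream().mapToLong(Accountable::ramBytesUsed).sum()
                + RamUsageEstimator.NUM_BYTES_OBJECT_REF * items.size();
        }

        public static void main(String[] args) {
            Accountable sixtyFourBytes = () -> 64L; // Accountable's single abstract method
            System.out.println(sizeOf(List.of(sixtyFourBytes, sixtyFourBytes)));
        }
    }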
private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
Instant createTime = Instant.now();
// The native process does not provide estimates for the custom feature_processor objects
long customProcessorSize = customProcessorSize();
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
currentModelId.set(modelId);
List<ExtractedField> fieldNames = extractedFields.getAllFields();
@@ -214,7 +230,7 @@ public class ChunkedTrainedModelPersister {
.setDescription(analytics.getDescription())
.setMetadata(Collections.singletonMap("analytics_config",
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
.setEstimatedHeapMemory(modelSize.ramBytesUsed() + customProcessorSize)
.setEstimatedOperations(modelSize.numOperations())
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
.setLicenseLevel(License.OperationMode.PLATINUM.description())


@@ -12,7 +12,7 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -21,27 +21,39 @@ import java.util.Set;
import java.util.stream.Collectors;
/**
 * The fields the data[feed|frame] has to extract
 */
public class ExtractedFields {
private final List<ExtractedField> allFields;
private final List<ExtractedField> docValueFields;
private final List<ProcessedField> processedFields;
private final String[] sourceFields;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
public ExtractedFields(List<ExtractedField> allFields,
List<ProcessedField> processedFields,
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.allFields = new ArrayList<>(allFields);
this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
.toArray(String[]::new);
this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
this.processedFields = processedFields == null ? Collections.emptyList() : processedFields;
}
public List<ProcessedField> getProcessedFields() {
return processedFields;
}
public List<ExtractedField> getAllFields() {
return allFields;
}
public Set<String> getProcessedFieldInputs() {
return processedFields.stream().map(ProcessedField::getInputFieldNames).flatMap(List::stream).collect(Collectors.toSet());
}
public String[] getSourceFields() {
return sourceFields;
}
@@ -58,11 +70,15 @@ public class ExtractedFields {
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
}
public static ExtractedFields build(Set<String> allFields,
Set<String> scriptFields,
FieldCapabilitiesResponse fieldsCapabilities,
Map<String, Long> cardinalitiesForFieldsWithConstraints,
List<ProcessedField> processedFields) {
ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
return new ExtractedFields(
allFields.stream().map(extractionMethodDetector::detect).collect(Collectors.toList()),
processedFields,
cardinalitiesForFieldsWithConstraints);
}


@@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.extractor;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
public class ProcessedField {
private final PreProcessor preProcessor;
public ProcessedField(PreProcessor processor) {
this.preProcessor = Objects.requireNonNull(processor);
}
public List<String> getInputFieldNames() {
return preProcessor.inputFields();
}
public List<String> getOutputFieldNames() {
return preProcessor.outputFields();
}
public Set<String> getOutputFieldType(String outputField) {
return Collections.singleton(preProcessor.getOutputFieldType(outputField));
}
public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
for (String field : preProcessor.inputFields()) {
ExtractedField extractedField = fieldExtractor.apply(field);
if (extractedField == null) {
return new Object[0];
}
Object[] values = extractedField.value(hit);
if (values == null || values.length == 0) {
continue;
}
final Object value = values[0];
if (values.length == 1 && (value instanceof String || value instanceof Number)) {
inputs.put(field, value);
}
}
preProcessor.process(inputs);
return preProcessor.outputFields().stream().map(inputs::get).toArray();
}
public String getProcessorName() {
return preProcessor.getName();
}
}
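The value() contract above, restated as a tiny pure-Java analog (no ES types, simplified): gather single-valued String/Number inputs, short-circuit to an empty array when an input cannot be resolved at all, then let the pre-processor write its outputs into the same map and read them back in declared order.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Consumer;
    import java.util.function.Function;

    public class ProcessedFieldContractSketch {
        static Object[] value(List<String> inputFields,
                              List<String> outputFields,
                              Function<String, Object> resolver,
                              Consumer<Map<String, Object>> processor) {
            Map<String, Object> inputs = new HashMap<>();
            for (String field : inputFields) {
                Object v = resolver.apply(field);
                if (v == null) {
                    return new Object[0]; // models a missing ExtractedField for an input
                }
                if (v instanceof String || v instanceof Number) {
                    inputs.put(field, v);
                }
            }
            processor.accept(inputs); // the pre-processor writes outputs into the same map
            return outputFields.stream().map(inputs::get).toArray();
        }

        public static void main(String[] args) {
            Object[] out = value(List.of("color"), List.of("color_blue"),
                field -> "blue",
                fields -> fields.put("color_blue", "blue".equals(fields.get("color")) ? 1 : 0));
            System.out.println(Arrays.toString(out)); // [1]
        }
    }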


@@ -128,4 +128,11 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
return responseHolder.get();
}
public static void assertNoException(AtomicReference<Exception> error) throws Exception {
if (error.get() == null) {
return;
}
throw error.get();
}
}


@@ -15,8 +15,10 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
@ -27,10 +29,13 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory; import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.ProcessedField;
import org.elasticsearch.xpack.ml.extractor.SourceField; import org.elasticsearch.xpack.ml.extractor.SourceField;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.junit.Before; import org.junit.Before;
@ -45,8 +50,10 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Queue; import java.util.Queue;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
@ -83,7 +90,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
query = QueryBuilders.matchAllQuery(); query = QueryBuilders.matchAllQuery();
extractedFields = new ExtractedFields(Arrays.asList( extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("keyword")), new DocValueField("field_1", Collections.singleton("keyword")),
new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap()); new DocValueField("field_2", Collections.singleton("keyword"))),
Collections.emptyList(),
Collections.emptyMap());
scrollSize = 1000; scrollSize = 1000;
headers = Collections.emptyMap(); headers = Collections.emptyMap();
@ -304,7 +313,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
// Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
extractedFields = new ExtractedFields(Arrays.asList( extractedFields = new ExtractedFields(Arrays.asList(
(ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")), (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap()); (ExtractedField) new SourceField("field_2", Collections.singleton("text"))),
Collections.emptyList(),
Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(false, false); TestExtractor dataExtractor = createExtractor(false, false);
@ -445,7 +456,9 @@ public class DataFrameDataExtractorTests extends ESTestCase {
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")), (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")), (ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")), (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap()); (ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
Collections.emptyList(),
Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(true, true); TestExtractor dataExtractor = createExtractor(true, true);
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty()); assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
@ -465,12 +478,100 @@ public class DataFrameDataExtractorTests extends ESTestCase {
containsInAnyOrder("field_keyword", "field_text", "field_boolean")); containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
} }
public void testGetFieldNames_GivenProcessesFeatures() {
// Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
extractedFields = new ExtractedFields(Arrays.asList(
(ExtractedField) new DocValueField("field_boolean", Collections.singleton("boolean")),
(ExtractedField) new DocValueField("field_float", Collections.singleton("float")),
(ExtractedField) new DocValueField("field_double", Collections.singleton("double")),
(ExtractedField) new DocValueField("field_byte", Collections.singleton("byte")),
(ExtractedField) new DocValueField("field_short", Collections.singleton("short")),
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
Arrays.asList(
new ProcessedField(new CategoricalPreProcessor("field_long", "animal")),
buildProcessedField("field_short", "field_1", "field_2")
),
Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(true, true);
assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")),
containsInAnyOrder("field_keyword", "field_text", "animal"));
List<String> fieldNames = dataExtractor.getFieldNames();
assertThat(fieldNames, containsInAnyOrder(
"animal",
"field_1",
"field_2",
"field_boolean",
"field_float",
"field_double",
"field_byte",
"field_integer",
"field_keyword",
"field_text"));
assertThat(dataExtractor.getFieldNames(), contains(fieldNames.toArray(new String[0])));
}
public void testExtractionWithProcessedFeatures() throws IOException {
extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("keyword")),
new DocValueField("field_2", Collections.singleton("keyword"))),
Arrays.asList(
new ProcessedField(new CategoricalPreProcessor("field_1", "animal")),
new ProcessedField(new OneHotEncoding("field_1",
Arrays.asList("11", "12")
.stream()
.collect(Collectors.toMap(Function.identity(), s -> s.equals("11") ? "field_11" : "field_12")),
true))
),
Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(true, true);
// First and only batch
SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
dataExtractor.setNextResponse(response1);
// Empty
SearchResponse lastAndEmptyResponse = createEmptySearchResponse();
dataExtractor.setNextResponse(lastAndEmptyResponse);
assertThat(dataExtractor.hasNext(), is(true));
// First batch
Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
assertThat(rows.isPresent(), is(true));
assertThat(rows.get().size(), equalTo(3));
assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"21", "dog", "1", "0"}));
assertThat(rows.get().get(1).getValues(),
equalTo(new String[] {"22", "dog", DataFrameDataExtractor.NULL_VALUE, DataFrameDataExtractor.NULL_VALUE}));
assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"23", "dog", "0", "0"}));
assertThat(rows.get().get(0).shouldSkip(), is(false));
assertThat(rows.get().get(1).shouldSkip(), is(false));
assertThat(rows.get().get(2).shouldSkip(), is(false));
}
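
To read the expected rows above: field_1 is consumed by the two feature processors, so each row starts with field_2's value ("21"), followed by the categorical output ("dog") and the two one-hot columns derived from field_1; the second document, whose field_1 is missing, gets NULL_VALUE in both one-hot columns.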
private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
return new TestExtractor(client, context);
}
private static ProcessedField buildProcessedField(String inputField, String... outputFields) {
return new ProcessedField(buildPreProcessor(inputField, outputFields));
}
private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
return new OneHotEncoding(inputField,
Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
true);
}
private SearchResponse createSearchResponse(List<Number> field1Values, List<Number> field2Values) {
assertThat(field1Values.size(), equalTo(field2Values.size()));
SearchResponse searchResponse = mock(SearchResponse.class);
@@ -544,4 +645,70 @@
return searchResponse;
}
}
private static class CategoricalPreProcessor implements PreProcessor {
private final List<String> inputFields;
private final List<String> outputFields;
CategoricalPreProcessor(String inputField, String outputField) {
this.inputFields = Arrays.asList(inputField);
this.outputFields = Arrays.asList(outputField);
}
@Override
public List<String> inputFields() {
return inputFields;
}
@Override
public List<String> outputFields() {
return outputFields;
}
@Override
public void process(Map<String, Object> fields) {
fields.put(outputFields.get(0), "dog");
}
@Override
public Map<String, String> reverseLookup() {
return null;
}
@Override
public boolean isCustom() {
return true;
}
@Override
public String getOutputFieldType(String outputField) {
return "text";
}
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public String getWriteableName() {
return null;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
}
@Override
public String getName() {
return null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return null;
}
}
}


@@ -15,10 +15,13 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
@@ -30,11 +33,14 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@@ -929,12 +935,23 @@
assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]"));
}
private static FieldCapabilitiesResponse simpleFieldResponse() {
return new MockFieldCapsResponseBuilder()
.addAggregatableField("field_11", "float")
.addNonAggregatableField("field_21", "float")
.addAggregatableField("field_21.child", "float")
.addNonAggregatableField("field_31", "float")
.addAggregatableField("field_31.child", "float")
.addNonAggregatableField("object_field", "object")
.build();
}
public void testDetect_GivenAnalyzedFieldExcludesObjectField() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("float_field", "float")
.addNonAggregatableField("object_field", "object").build();
analyzedFields = new FetchSourceContext(true, null, new String[]{"object_field"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap());
@@ -943,6 +960,177 @@
assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]"));
}
public void testDetect_givenFeatureProcessorsFailures_ResultsField() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("ml.result", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("fields contained in results field [ml] cannot be used in a feature_processor"));
}
public void testDetect_givenFeatureProcessorsFailures_Objects() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("object_field", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("fields for feature_processors must not be objects"));
}
public void testDetect_givenFeatureProcessorsFailures_ReservedFields() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("_id", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("the following fields cannot be used in feature_processors"));
}
public void testDetect_givenFeatureProcessorsFailures_MissingFieldFromIndex() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("bar", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("the fields [bar] were not found in the field capabilities of the source indices"));
}
public void testDetect_givenFeatureProcessorsFailures_UsingRequiredField() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("required analysis fields [field_31] cannot be used in a feature_processor"));
}
public void testDetect_givenFeatureProcessorsFailures_BadSourceFiltering() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
sourceFiltering = new FetchSourceContext(true, null, new String[]{"field_1*"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("fields [field_11] required by field_processors are not included in source filtering."));
}
public void testDetect_givenFeatureProcessorsFailures_MissingAnalyzedField() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
analyzedFields = new FetchSourceContext(true, null, new String[]{"field_1*"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("fields [field_11] required by field_processors are not included in the analyzed_fields"));
}
public void testDetect_givenFeatureProcessorsFailures_RequiredMultiFields() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31.child", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("feature_processors cannot be applied to required fields for analysis; "));
extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31.child", Arrays.asList(buildPreProcessor("field_31", "foo"))),
100,
fieldCapabilities,
Collections.emptyMap());
ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("feature_processors cannot be applied to required fields for analysis; "));
}
public void testDetect_givenFeatureProcessorsFailures_BothMultiFields() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31",
Arrays.asList(
buildPreProcessor("field_21", "foo"),
buildPreProcessor("field_21.child", "bar")
)),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("feature_processors refer to both multi-field "));
}
public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFields() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31",
Arrays.asList(
buildPreProcessor("field_11", "foo"),
buildPreProcessor("field_21", "foo")
)),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString("feature_processors must define unique output field names; duplicate fields [foo]"));
}
public void testDetect_withFeatureProcessors() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("field_11", "float")
.addAggregatableField("field_21", "float")
.addNonAggregatableField("field_31", "float")
.addAggregatableField("field_31.child", "float")
.addNonAggregatableField("object_field", "object")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_11",
Arrays.asList(buildPreProcessor("field_31", "foo", "bar"))),
100,
fieldCapabilities,
Collections.emptyMap());
ExtractedFields extracted = extractedFieldsDetector.detect().v1();
assertThat(extracted.getProcessedFieldInputs(), containsInAnyOrder("field_31"));
assertThat(extracted.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()),
containsInAnyOrder("field_11", "field_21", "field_31"));
assertThat(extracted.getSourceFields(), arrayContainingInAnyOrder("field_31"));
assertThat(extracted.getDocValueFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()),
containsInAnyOrder("field_21", "field_11"));
assertThat(extracted.getProcessedFields(), hasSize(1));
}
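
Taken together, these cases pin down the validation contract. A minimal sketch of a config that passes detection, built the same way as the buildRegressionConfig helper below (the id is illustrative and the nulls stand for the Regression parameters left unset here):

    DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
        .setId("example")
        .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, null))
        .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
        .setAnalysis(new Regression("field_11",
            BoostedTreeParams.builder().build(),
            null, null, null, null, null,
            // reads field_31 and emits new columns foo/bar, colliding with neither
            // required analysis fields nor existing feature names
            Collections.singletonList(buildPreProcessor("field_31", "foo", "bar"))))
        .build();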
private DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return new DataFrameAnalyticsConfig.Builder()
.setId("foo")
@@ -954,13 +1142,7 @@
}
private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) {
return buildRegressionConfig(dependentVariable, Collections.emptyList());
}
private DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) {
@@ -972,6 +1154,29 @@
.build();
}
private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, List<PreProcessor> featureProcessors) {
return new DataFrameAnalyticsConfig.Builder()
.setId("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
.setAnalyzedFields(analyzedFields)
.setAnalysis(new Regression(dependentVariable,
BoostedTreeParams.builder().build(),
null,
null,
null,
null,
null,
featureProcessors))
.build();
}
private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
return new OneHotEncoding(inputField,
Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
true);
}
/**
* We assert each field individually to get useful error messages in case of failure
*/
@@ -987,6 +1192,23 @@
}
}
public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFieldsWithUnProcessedField() {
FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
buildRegressionConfig("field_31",
Arrays.asList(
buildPreProcessor("field_11", "field_21")
)),
100,
fieldCapabilities,
Collections.emptyMap());
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(ex.getMessage(),
containsString(
"feature_processors output fields must not include non-processed analysis fields; duplicate fields [field_21]"));
}
private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();


@@ -80,6 +80,7 @@
public void testInferTestDocs() {
ExtractedFields extractedFields = new ExtractedFields(
Collections.singletonList(new SourceField("key", Collections.singleton("integer"))),
Collections.emptyList(),
Collections.emptyMap());
Map<String, Object> doc1 = new HashMap<>();


@@ -63,7 +63,9 @@
public void testToXContent_GivenOutlierDetection() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float"))),
Collections.emptyList(),
Collections.emptyMap());
DataFrameAnalysis analysis = new OutlierDetection.Builder().build();
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -82,7 +84,9 @@
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("keyword"))),
Collections.emptyList(),
Collections.emptyMap());
DataFrameAnalysis analysis = new Regression("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -103,7 +107,9 @@
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("keyword"))),
Collections.emptyList(),
Collections.singletonMap("test_dep_var", 5L));
DataFrameAnalysis analysis = new Classification("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
@@ -126,7 +132,9 @@
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("integer"))),
Collections.emptyList(),
Collections.singletonMap("test_dep_var", 8L));
DataFrameAnalysis analysis = new Classification("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);


@@ -105,7 +105,9 @@
OutlierDetectionTests.createRandom()).build();
dataExtractor = mock(DataFrameDataExtractor.class);
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(),
Collections.emptyList(),
Collections.emptyMap()));
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));


@@ -314,6 +314,6 @@
trainedModelProvider,
auditor,
statsPersister,
new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap()));
}
}


@@ -144,7 +144,7 @@
analyticsConfig,
auditor,
(unused) -> {},
new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap()));
}
}


@@ -16,6 +16,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.TreeSet;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
@@ -31,8 +32,10 @@
ExtractedField scriptField2 = new ScriptField("scripted2");
ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
ExtractedFields extractedFields = new ExtractedFields(
Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2),
Collections.emptyList(),
Collections.emptyMap());
assertThat(extractedFields.getAllFields().size(), equalTo(6));
assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
@@ -53,8 +56,11 @@
when(fieldCapabilitiesResponse.getField("value")).thenReturn(valueCaps);
when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps);
ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("time", "value", "airline", "airport")),
new HashSet<>(Collections.singletonList("airport")),
fieldCapabilitiesResponse,
Collections.emptyMap(),
Collections.emptyList());
assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
@@ -76,8 +82,8 @@
when(fieldCapabilitiesResponse.getField("airport")).thenReturn(text);
when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword);
ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("airline.text", "airport.keyword")),
Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap(), Collections.emptyList());
assertThat(extractedFields.getDocValueFields().size(), equalTo(1));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword"));
@@ -112,14 +118,18 @@
assertThat(mapped.getName(), equalTo(aBool.getName()));
assertThat(mapped.getMethod(), equalTo(aBool.getMethod()));
assertThat(mapped.supportsFromSource(), is(false));
expectThrows(UnsupportedOperationException.class, mapped::newFromSource);
}
public void testBuildGivenFieldWithoutMappings() {
FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build(
Collections.singleton("value"),
Collections.emptySet(),
fieldCapabilitiesResponse,
Collections.emptyMap(),
Collections.emptyList()));
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
}


@@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.extractor;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.util.Arrays;
import java.util.Collections;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.emptyArray;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ProcessedFieldTests extends ESTestCase {
public void testOneHotGetters() {
String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
assertThat(processedField.getInputFieldNames(), hasItems(inputField));
assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
assertThat(processedField.getOutputFieldType("baz_column"), equalTo(Collections.singleton("integer")));
assertThat(processedField.getProcessorName(), equalTo(OneHotEncoding.NAME.getPreferredName()));
}
public void testMissingExtractor() {
String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
}
public void testMissingInputValues() {
String inputField = "foo";
ExtractedField extractedField = makeExtractedField(new Object[0]);
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
}
public void testProcessedField() {
ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0));
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1));
}
private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) {
return new OneHotEncoding(inputField,
Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")),
true);
}
private static ExtractedField makeExtractedField(Object[] value) {
ExtractedField extractedField = mock(ExtractedField.class);
when(extractedField.value(any())).thenReturn(value);
return extractedField;
}
private static SearchHit makeHit() {
return new SearchHitBuilder(42).addField("a_keyword", "bar").build();
}
}